forked from mindspore-Ecosystem/mindspore
!22979 Add ModelArts training, SDK, and MxBase inference support to the ResNext18 model
Merge pull request !22979 from Atlas_ymc/master
Commit c7a60b2e4c
@ -18,6 +18,7 @@
# Modelzoo
"mindspore/model_zoo/official/cv/yolov4_tiny/infer/mxbase/src/Yolov4TinyDetection.h" "runtime/references"
"mindspore/model_zoo/official/cv/yolov4_tiny/infer/mxbase/src/PostProcess/Yolov4TinyMindsporePost.h" "runtime/references"
"mindspore/model_zoo/official/cv/resnet/infer/ResNet18/mxbase/Resnet18ClassifyOpencv.h" "runtime/references"

# MindData
"mindspore/mindspore/ccsrc/minddata/mindrecord/include/shard_page.h" "runtime/string"
@ -0,0 +1,16 @@
aipp_op {
aipp_mode: static
input_format : RGB888_U8

rbuv_swap_switch : true

mean_chn_0 : 0
mean_chn_1 : 0
mean_chn_2 : 0
min_chn_0 : 123.675
min_chn_1 : 116.28
min_chn_2 : 103.53
var_reci_chn_0 : 0.0171247538316637
var_reci_chn_1 : 0.0175070028011204
var_reci_chn_2 : 0.0174291938997821
}
@ -0,0 +1,30 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

air_path=$1
aipp_cfg_path=$2
om_path=$3

/usr/local/Ascend/atc/bin/atc \
--model="$air_path" \
--framework=1 \
--output="$om_path" \
--input_format=NCHW --input_shape="actual_input_1:1,3,304,304" \
--enable_small_channel=1 \
--log=error \
--soc_version=Ascend310 \
--insert_op_conf="$aipp_cfg_path" \
--output_type=FP32
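A hedged usage sketch for the conversion script above: the script and file names are placeholders, only the three positional arguments (air_path, aipp_cfg_path, om_path) are defined by the script itself.

# Placeholder file names; the output prefix matches the om name referenced later by the inference code.
bash convert_om.sh ./resnet18.air ./aipp_resnet18_rgb.cfg ./resnet18-304_304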
@ -0,0 +1,3 @@
CLASS_NUM=1001
SOFTMAX=false
TOP_K=5
@ -0,0 +1,48 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
docker_image=$1
|
||||
data_path=$2
|
||||
|
||||
function show_help() {
|
||||
echo "Usage: docker_start.sh docker_image data_path"
|
||||
}
|
||||
|
||||
function param_check() {
|
||||
if [ -z "${docker_image}" ]; then
|
||||
echo "please input docker_image"
|
||||
show_help
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -z "${data_path}" ]; then
|
||||
echo "please input data_path"
|
||||
show_help
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
param_check
|
||||
|
||||
docker run -it \
|
||||
--device=/dev/davinci0 \
|
||||
--device=/dev/davinci_manager \
|
||||
--device=/dev/devmm_svm \
|
||||
--device=/dev/hisi_hdc \
|
||||
-v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
|
||||
-v ${data_path}:${data_path} \
|
||||
${docker_image} \
|
||||
/bin/bash
|
|
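A hedged usage sketch for docker_start.sh: the image tag and data path are placeholders; the script only checks that both arguments are non-empty before starting the container.

# Placeholder image name and host data path; argument order comes from show_help above.
bash docker_start.sh mindspore-ascend:latest /home/data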
@ -0,0 +1,48 @@
|
|||
cmake_minimum_required(VERSION 3.14.0)
|
||||
project(resnet)
|
||||
set(TARGET resnet)
|
||||
|
||||
add_definitions(-DENABLE_DVPP_INTERFACE)
|
||||
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
|
||||
add_definitions(-Dgoogle=mindxsdk_private)
|
||||
add_compile_options(-std=c++11 -fPIE -fstack-protector-all -fPIC -Wall)
|
||||
add_link_options(-Wl,-z,relro,-z,now,-z,noexecstack -s -pie)
|
||||
|
||||
# Check environment variable
|
||||
if(NOT DEFINED ENV{ASCEND_HOME})
|
||||
message(FATAL_ERROR "please define environment variable:ASCEND_HOME")
|
||||
endif()
|
||||
if(NOT DEFINED ENV{ASCEND_VERSION})
|
||||
message(WARNING "please define environment variable:ASCEND_VERSION")
|
||||
endif()
|
||||
if(NOT DEFINED ENV{ARCH_PATTERN})
|
||||
message(WARNING "please define environment variable:ARCH_PATTERN")
|
||||
endif()
|
||||
set(ACL_INC_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/include)
|
||||
set(ACL_LIB_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/lib64)
|
||||
|
||||
set(MXBASE_ROOT_DIR $ENV{MX_SDK_HOME})
|
||||
set(MXBASE_INC ${MXBASE_ROOT_DIR}/include)
|
||||
set(MXBASE_LIB_DIR ${MXBASE_ROOT_DIR}/lib)
|
||||
set(MXBASE_POST_LIB_DIR ${MXBASE_ROOT_DIR}/lib/modelpostprocessors)
|
||||
set(MXBASE_POST_PROCESS_DIR ${MXBASE_ROOT_DIR}/include/MxBase/postprocess/include)
|
||||
set(OPENSOURCE_DIR ${MXBASE_ROOT_DIR}/opensource)
|
||||
|
||||
include_directories(${ACL_INC_DIR})
|
||||
include_directories(${OPENSOURCE_DIR}/include)
|
||||
include_directories(${OPENSOURCE_DIR}/include/opencv4)
|
||||
|
||||
include_directories(${MXBASE_INC})
|
||||
include_directories(${MXBASE_POST_PROCESS_DIR})
|
||||
|
||||
link_directories(${ACL_LIB_DIR})
|
||||
link_directories(${OPENSOURCE_DIR}/lib)
|
||||
link_directories(${MXBASE_LIB_DIR})
|
||||
link_directories(${MXBASE_POST_LIB_DIR})
|
||||
|
||||
add_executable(${TARGET} main.cpp Resnet18ClassifyOpencv.cpp)
|
||||
|
||||
target_link_libraries(${TARGET} glog cpprest mxbase resnet50postprocess opencv_world stdc++fs)
|
||||
|
||||
install(TARGETS ${TARGET} RUNTIME DESTINATION ${PROJECT_SOURCE_DIR}/)
|
||||
|
|
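The CMake file above aborts if ASCEND_HOME is undefined, warns when ASCEND_VERSION or ARCH_PATTERN are unset, and resolves the MindX SDK through MX_SDK_HOME. A hedged environment sketch follows; the concrete paths are typical install locations and are assumptions, not values taken from this change.

# Assumed install locations; only the variable names come from CMakeLists.txt.
export ASCEND_HOME=/usr/local/Ascend
export ASCEND_VERSION=ascend-toolkit/latest
export ARCH_PATTERN=x86_64-linux
export MX_SDK_HOME=/path/to/MindX_SDK/mxVision
bash build.sh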
@ -0,0 +1,217 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "Resnet18ClassifyOpencv.h"
|
||||
#include "MxBase/DeviceManager/DeviceManager.h"
|
||||
#include "MxBase/Log/Log.h"
|
||||
|
||||
using MxBase::DeviceManager;
|
||||
using MxBase::TensorBase;
|
||||
using MxBase::MemoryData;
|
||||
using MxBase::ClassInfo;
|
||||
|
||||
namespace {
|
||||
const uint32_t YUV_BYTE_NU = 3;
|
||||
const uint32_t YUV_BYTE_DE = 2;
|
||||
const uint32_t VPC_H_ALIGN = 2;
|
||||
}
|
||||
|
||||
APP_ERROR Resnet18ClassifyOpencv::Init(const InitParam &initParam) {
|
||||
deviceId_ = initParam.deviceId;
|
||||
APP_ERROR ret = DeviceManager::GetInstance()->InitDevices();
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "Init devices failed, ret=" << ret << ".";
|
||||
return ret;
|
||||
}
|
||||
ret = MxBase::TensorContext::GetInstance()->SetContext(initParam.deviceId);
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "Set context failed, ret=" << ret << ".";
|
||||
return ret;
|
||||
}
|
||||
dvppWrapper_ = std::make_shared<MxBase::DvppWrapper>();
|
||||
ret = dvppWrapper_->Init();
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "DvppWrapper init failed, ret=" << ret << ".";
|
||||
return ret;
|
||||
}
|
||||
model_ = std::make_shared<MxBase::ModelInferenceProcessor>();
|
||||
ret = model_->Init(initParam.modelPath, modelDesc_);
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "ModelInferenceProcessor init failed, ret=" << ret << ".";
|
||||
return ret;
|
||||
}
|
||||
MxBase::ConfigData configData;
|
||||
const std::string softmax = initParam.softmax ? "true" : "false";
|
||||
const std::string checkTensor = initParam.checkTensor ? "true" : "false";
|
||||
|
||||
configData.SetJsonValue("CLASS_NUM", std::to_string(initParam.classNum));
|
||||
configData.SetJsonValue("TOP_K", std::to_string(initParam.topk));
|
||||
configData.SetJsonValue("SOFTMAX", softmax);
|
||||
configData.SetJsonValue("CHECK_MODEL", checkTensor);
|
||||
|
||||
auto jsonStr = configData.GetCfgJson().serialize();
|
||||
std::map<std::string, std::shared_ptr<void>> config;
|
||||
config["postProcessConfigContent"] = std::make_shared<std::string>(jsonStr);
|
||||
config["labelPath"] = std::make_shared<std::string>(initParam.labelPath);
|
||||
|
||||
post_ = std::make_shared<MxBase::Resnet50PostProcess>();
|
||||
ret = post_->Init(config);
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "Resnet50PostProcess init failed, ret=" << ret << ".";
|
||||
return ret;
|
||||
}
|
||||
return APP_ERR_OK;
|
||||
}
|
||||
|
||||
APP_ERROR Resnet18ClassifyOpencv::DeInit() {
|
||||
dvppWrapper_->DeInit();
|
||||
model_->DeInit();
|
||||
post_->DeInit();
|
||||
DeviceManager::GetInstance()->DestroyDevices();
|
||||
return APP_ERR_OK;
|
||||
}
|
||||
|
||||
APP_ERROR Resnet18ClassifyOpencv::ConvertImageToTensorBase(const std::string &imgPath,
|
||||
TensorBase &tensorBase) {
|
||||
static constexpr uint32_t resizeHeight = 304;
|
||||
static constexpr uint32_t resizeWidth = 304;
|
||||
|
||||
cv::Mat imageMat = cv::imread(imgPath, cv::IMREAD_COLOR);
|
||||
cv::resize(imageMat, imageMat, cv::Size(resizeWidth, resizeHeight));
|
||||
const uint32_t dataSize = imageMat.cols * imageMat.rows * MxBase::XRGB_WIDTH_NU;
|
||||
LogInfo << "image size after resize" << imageMat.cols << " " << imageMat.rows;
|
||||
|
||||
MemoryData memoryDataDst(dataSize, MemoryData::MEMORY_DEVICE, deviceId_);
|
||||
MemoryData memoryDataSrc(imageMat.data, dataSize, MemoryData::MEMORY_HOST_MALLOC);
|
||||
|
||||
APP_ERROR ret = MxBase::MemoryHelper::MxbsMallocAndCopy(memoryDataDst, memoryDataSrc);
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << GetError(ret) << "Memory malloc failed.";
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::vector<uint32_t> shape = {imageMat.rows * MxBase::XRGB_WIDTH_NU, static_cast<uint32_t>(imageMat.cols)};
|
||||
tensorBase = TensorBase(memoryDataDst, false, shape, MxBase::TENSOR_DTYPE_UINT8);
|
||||
return APP_ERR_OK;
|
||||
}
|
||||
|
||||
APP_ERROR Resnet18ClassifyOpencv::Inference(std::vector<TensorBase> &inputs,
|
||||
std::vector<TensorBase> &outputs) {
|
||||
auto dtypes = model_->GetOutputDataType();
|
||||
for (size_t i = 0; i < modelDesc_.outputTensors.size(); ++i) {
|
||||
std::vector<uint32_t> shape = {};
|
||||
for (size_t j = 0; j < modelDesc_.outputTensors[i].tensorDims.size(); ++j) {
|
||||
shape.push_back((uint32_t)modelDesc_.outputTensors[i].tensorDims[j]);
|
||||
}
|
||||
TensorBase tensor(shape, dtypes[i], MemoryData::MemoryType::MEMORY_DEVICE, deviceId_);
|
||||
APP_ERROR ret = TensorBase::TensorBaseMalloc(tensor);
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "TensorBaseMalloc failed, ret=" << ret << ".";
|
||||
return ret;
|
||||
}
|
||||
outputs.push_back(tensor);
|
||||
}
|
||||
MxBase::DynamicInfo dynamicInfo = {};
|
||||
dynamicInfo.dynamicType = MxBase::DynamicType::STATIC_BATCH;
|
||||
auto startTime = std::chrono::high_resolution_clock::now();
|
||||
APP_ERROR ret = model_->ModelInference(inputs, outputs, dynamicInfo);
|
||||
auto endTime = std::chrono::high_resolution_clock::now();
|
||||
double costMs = std::chrono::duration<double, std::milli>(endTime - startTime).count();
|
||||
inferCostTimeMilliSec += costMs;
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "ModelInference failed, ret=" << ret << ".";
|
||||
return ret;
|
||||
}
|
||||
return APP_ERR_OK;
|
||||
}
|
||||
|
||||
APP_ERROR Resnet18ClassifyOpencv::PostProcess(std::vector<TensorBase> &inputs,
|
||||
std::vector<std::vector<ClassInfo>> &clsInfos) {
|
||||
APP_ERROR ret = post_->Process(inputs, clsInfos);
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "Process failed, ret=" << ret << ".";
|
||||
return ret;
|
||||
}
|
||||
return APP_ERR_OK;
|
||||
}
|
||||
|
||||
APP_ERROR Resnet18ClassifyOpencv::SaveResult(const std::string &imgPath,
|
||||
std::vector<std::vector<ClassInfo>> &batchClsInfos) {
|
||||
LogInfo << "image path" << imgPath;
|
||||
std::string fileName = imgPath.substr(imgPath.find_last_of("/") + 1);
|
||||
size_t dot = fileName.find_last_of(".");
|
||||
std::string resFileName = "result/" + fileName.substr(0, dot) + "_1.txt";
|
||||
LogInfo << "file path for saving result" << resFileName;
|
||||
|
||||
std::ofstream outfile(resFileName);
|
||||
if (outfile.fail()) {
|
||||
LogError << "Failed to open result file: ";
|
||||
return APP_ERR_COMM_FAILURE;
|
||||
}
|
||||
|
||||
uint32_t batchIndex = 0;
|
||||
for (auto clsInfos : batchClsInfos) {
|
||||
std::string resultStr;
|
||||
for (auto clsInfo : clsInfos) {
|
||||
LogDebug << " className:" << clsInfo.className << " confidence:" << clsInfo.confidence <<
|
||||
" classIndex:" << clsInfo.classId;
|
||||
resultStr += std::to_string(clsInfo.classId) + " ";
|
||||
}
|
||||
|
||||
outfile << resultStr << std::endl;
|
||||
batchIndex++;
|
||||
}
|
||||
outfile.close();
|
||||
return APP_ERR_OK;
|
||||
}
|
||||
|
||||
APP_ERROR Resnet18ClassifyOpencv::Process(const std::string &imgPath) {
|
||||
TensorBase tensorBase;
|
||||
std::vector<TensorBase> inputs;
|
||||
std::vector<TensorBase> outputs;
|
||||
|
||||
APP_ERROR ret = ConvertImageToTensorBase(imgPath, tensorBase);
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "Convert image to TensorBase failed, ret=" << ret << ".";
|
||||
return ret;
|
||||
}
|
||||
|
||||
inputs.push_back(tensorBase);
|
||||
|
||||
auto startTime = std::chrono::high_resolution_clock::now();
|
||||
ret = Inference(inputs, outputs);
|
||||
auto endTime = std::chrono::high_resolution_clock::now();
|
||||
double costMs = std::chrono::duration<double, std::milli>(endTime - startTime).count();
|
||||
inferCostTimeMilliSec += costMs;
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "Inference failed, ret=" << ret << ".";
|
||||
return ret;
|
||||
}
|
||||
std::vector<std::vector<ClassInfo>> BatchClsInfos;
|
||||
ret = PostProcess(outputs, BatchClsInfos);
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "PostProcess failed, ret=" << ret << ".";
|
||||
return ret;
|
||||
}
|
||||
|
||||
ret = SaveResult(imgPath, BatchClsInfos);
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "Save infer results into file failed. ret = " << ret << ".";
|
||||
return ret;
|
||||
}
|
||||
|
||||
return APP_ERR_OK;
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MXBASE_RESNET18CLASSIFYOPENCV_H
|
||||
#define MXBASE_RESNET18CLASSIFYOPENCV_H
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <opencv2/opencv.hpp>
|
||||
#include "MxBase/DvppWrapper/DvppWrapper.h"
|
||||
#include "MxBase/ModelInfer/ModelInferenceProcessor.h"
|
||||
#include "ClassPostProcessors/Resnet50PostProcess.h"
|
||||
#include "MxBase/Tensor/TensorContext/TensorContext.h"
|
||||
|
||||
struct InitParam {
|
||||
uint32_t deviceId;
|
||||
std::string labelPath;
|
||||
uint32_t classNum;
|
||||
uint32_t topk;
|
||||
bool softmax;
|
||||
bool checkTensor;
|
||||
std::string modelPath;
|
||||
};
|
||||
|
||||
class Resnet18ClassifyOpencv {
|
||||
public:
|
||||
APP_ERROR Init(const InitParam &initParam);
|
||||
APP_ERROR DeInit();
|
||||
APP_ERROR ConvertImageToTensorBase(const std::string &imgPath, MxBase::TensorBase &tensorBase);
|
||||
APP_ERROR Inference(std::vector<MxBase::TensorBase> &inputs,
|
||||
std::vector<MxBase::TensorBase> &outputs);
|
||||
APP_ERROR PostProcess(std::vector<MxBase::TensorBase> &inputs,
|
||||
std::vector<std::vector<MxBase::ClassInfo>> &clsInfos);
|
||||
APP_ERROR Process(const std::string &imgPath);
|
||||
// get infer time
|
||||
double GetInferCostMilliSec() const {return inferCostTimeMilliSec;}
|
||||
|
||||
private:
|
||||
APP_ERROR SaveResult(const std::string &imgPath,
|
||||
std::vector<std::vector<MxBase::ClassInfo>> &batchClsInfos);
|
||||
std::shared_ptr<MxBase::DvppWrapper> dvppWrapper_;
|
||||
std::shared_ptr<MxBase::ModelInferenceProcessor> model_;
|
||||
std::shared_ptr<MxBase::Resnet50PostProcess> post_;
|
||||
MxBase::ModelDesc modelDesc_;
|
||||
uint32_t deviceId_ = 0;
|
||||
// infer time
|
||||
double inferCostTimeMilliSec = 0.0;
|
||||
};
|
||||
|
||||
#endif
|
|
@ -0,0 +1,53 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
path_cur=$(dirname "$0")
|
||||
|
||||
function check_env()
|
||||
{
|
||||
# set ASCEND_VERSION to ascend-toolkit/latest when it was not specified by user
|
||||
if [ ! "${ASCEND_VERSION}" ]; then
|
||||
echo "Set ASCEND_VERSION to the default value: ${ASCEND_VERSION}"
|
||||
else
|
||||
echo "ASCEND_VERSION is set to ${ASCEND_VERSION} by user"
|
||||
fi
|
||||
|
||||
if [ ! "${ARCH_PATTERN}" ]; then
|
||||
# set ARCH_PATTERN to ./ when it was not specified by user
|
||||
echo "ARCH_PATTERN is set to the default value: ${ARCH_PATTERN}"
|
||||
else
|
||||
echo "ARCH_PATTERN is set to ${ARCH_PATTERN} by user"
|
||||
fi
|
||||
}
|
||||
|
||||
function build_resnet18()
|
||||
{
|
||||
cd "$path_cur" || exit
|
||||
rm -rf build
|
||||
mkdir -p build
|
||||
cd build || exit
|
||||
cmake ..
|
||||
make
|
||||
ret=$?
|
||||
if [ ${ret} -ne 0 ]; then
|
||||
echo "Failed to build resnet18."
|
||||
exit ${ret}
|
||||
fi
|
||||
make install
|
||||
}
|
||||
|
||||
check_env
|
||||
build_resnet18
|
|
@ -0,0 +1,90 @@
|
|||
/*
|
||||
* Copyright (c) 2021. Huawei Technologies Co., Ltd. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <dirent.h>
|
||||
#include "Resnet18ClassifyOpencv.h"
|
||||
#include "MxBase/Log/Log.h"
|
||||
|
||||
namespace {
|
||||
const uint32_t CLASS_NUM = 1001;
|
||||
}
|
||||
|
||||
APP_ERROR ReadFilesFromPath(const std::string &path, std::vector<std::string> *files) {
|
||||
DIR *dir = NULL;
|
||||
struct dirent *ptr = NULL;
|
||||
|
||||
if ((dir=opendir(path.c_str())) == NULL) {
|
||||
LogError << "Open dir error: " << path;
|
||||
return APP_ERR_COMM_OPEN_FAIL;
|
||||
}
|
||||
|
||||
while ((ptr=readdir(dir)) != NULL) {
|
||||
// d_type == DT_REG (8) means a regular file
|
||||
if (ptr->d_type == 8) {
|
||||
files->push_back(path + ptr->d_name);
|
||||
}
|
||||
}
|
||||
closedir(dir);
|
||||
// sort ascending order
|
||||
sort(files->begin(), files->end());
|
||||
return APP_ERR_OK;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
if (argc <= 1) {
|
||||
LogWarn << "Please input image path, such as './resnet image_dir'.";
|
||||
return APP_ERR_OK;
|
||||
}
|
||||
|
||||
InitParam initParam = {};
|
||||
initParam.deviceId = 0;
|
||||
initParam.classNum = CLASS_NUM;
|
||||
initParam.labelPath = "../data/config/imagenet1000_clsidx_to_labels.names";
|
||||
initParam.topk = 5;
|
||||
initParam.softmax = false;
|
||||
initParam.checkTensor = true;
|
||||
initParam.modelPath = "../data/model/resnet18-304_304.om";
|
||||
auto resnet18 = std::make_shared<Resnet18ClassifyOpencv>();
|
||||
APP_ERROR ret = resnet18->Init(initParam);
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "resnet18Classify init failed, ret=" << ret << ".";
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::string inferPath = argv[1];
|
||||
std::vector<std::string> files;
|
||||
ret = ReadFilesFromPath(inferPath, &files);
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "Read files from path failed, ret=" << ret << ".";
|
||||
return ret;
|
||||
}
|
||||
|
||||
auto startTime = std::chrono::high_resolution_clock::now();
|
||||
for (uint32_t i = 0; i < files.size(); i++) {
|
||||
ret = resnet18->Process(files[i]);
|
||||
if (ret != APP_ERR_OK) {
|
||||
LogError << "resnet18Classify process failed, ret=" << ret << ".";
|
||||
resnet18->DeInit();
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
auto endTime = std::chrono::high_resolution_clock::now();
|
||||
resnet18->DeInit();
|
||||
double costMilliSecs = std::chrono::duration<double, std::milli>(endTime - startTime).count();
|
||||
double fps = 1000.0 * files.size() / resnet18->GetInferCostMilliSec();
|
||||
LogInfo << "[Process Delay] cost: " << costMilliSecs << " ms\tfps: " << fps << " imgs/sec";
|
||||
return APP_ERR_OK;
|
||||
}
|
|
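A hedged run sketch for the MxBase binary built above: the image directory is a placeholder, while the single positional argument, the relative model/label paths, and the result/ output directory come from main.cpp and Resnet18ClassifyOpencv.cpp.

# The result/ directory must exist because SaveResult opens result/<name>_1.txt directly.
mkdir -p result
# Keep the trailing slash: ReadFilesFromPath concatenates the directory string and the file name as-is.
./resnet ../data/images/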
@ -0,0 +1,101 @@
|
|||
# coding=utf-8
|
||||
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
from StreamManagerApi import StreamManagerApi
|
||||
from StreamManagerApi import MxDataInput
|
||||
|
||||
|
||||
def run():
|
||||
# init stream manager
|
||||
stream_manager_api = StreamManagerApi()
|
||||
ret = stream_manager_api.InitManager()
|
||||
if ret != 0:
|
||||
print("Failed to init Stream manager, ret=%s" % str(ret))
|
||||
return
|
||||
|
||||
# create streams by pipeline config file
|
||||
with open("./resnet18.pipeline", 'rb') as f:
|
||||
pipelineStr = f.read()
|
||||
ret = stream_manager_api.CreateMultipleStreams(pipelineStr)
|
||||
|
||||
if ret != 0:
|
||||
print("Failed to create Stream, ret=%s" % str(ret))
|
||||
return
|
||||
|
||||
# Construct the input of the stream
|
||||
data_input = MxDataInput()
|
||||
|
||||
dir_name = sys.argv[1]
|
||||
res_dir_name = sys.argv[2]
|
||||
file_list = os.listdir(dir_name)
|
||||
if not os.path.exists(res_dir_name):
|
||||
os.makedirs(res_dir_name)
|
||||
|
||||
for file_name in file_list:
|
||||
file_path = os.path.join(dir_name, file_name)
|
||||
if not (file_name.lower().endswith(
|
||||
".jpg") or file_name.lower().endswith(".jpeg")):
|
||||
continue
|
||||
|
||||
with open(file_path, 'rb') as f:
|
||||
data_input.data = f.read()
|
||||
|
||||
stream_name = b'im_resnet18'
|
||||
in_plugin_id = 0
|
||||
unique_id = stream_manager_api.SendData(stream_name, in_plugin_id,
|
||||
data_input)
|
||||
if unique_id < 0:
|
||||
print("Failed to send data to stream.")
|
||||
return
|
||||
# Obtain the inference result by specifying streamName and uniqueId.
|
||||
start_time = datetime.datetime.now()
|
||||
infer_result = stream_manager_api.GetResult(stream_name, unique_id)
|
||||
end_time = datetime.datetime.now()
|
||||
print('sdk run time: {}'.format((end_time - start_time).microseconds))
|
||||
if infer_result.errorCode != 0:
|
||||
print("GetResultWithUniqueId error. errorCode=%d, errorMsg=%s" % (
|
||||
infer_result.errorCode, infer_result.data.decode()))
|
||||
return
|
||||
# print the infer result
|
||||
infer_res = infer_result.data.decode()
|
||||
print("process img: {}, infer result: {}".format(file_name, infer_res))
|
||||
|
||||
load_dict = json.loads(infer_result.data.decode())
|
||||
if load_dict.get('MxpiClass') is None:
|
||||
with open(res_dir_name + "/" + file_name[:-5] + '.txt',
|
||||
'w') as f_write:
|
||||
f_write.write("")
|
||||
continue
|
||||
res_vec = load_dict.get('MxpiClass')
|
||||
|
||||
with open(res_dir_name + "/" + file_name[:-5] + '_1.txt',
|
||||
'w') as f_write:
|
||||
res_list = [str(item.get("classId")) + " " for item in res_vec]
|
||||
f_write.writelines(res_list)
|
||||
f_write.write('\n')
|
||||
|
||||
# destroy streams
|
||||
stream_manager_api.DestroyAllStreams()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
|
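A hedged invocation sketch for the SDK script above: the directory names are placeholders; the two positional arguments are read from sys.argv, and the run.sh further below wraps the same call.

# Placeholder directories for input images and the result txt files.
python3.7 main.py ./val_images/ ./result/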
@ -0,0 +1,177 @@
|
|||
# coding = utf-8
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the BSD 3-Clause License (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://opensource.org/licenses/BSD-3-Clause
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
np.set_printoptions(threshold=sys.maxsize)
|
||||
|
||||
LABEL_FILE = "HiAI_label.json"
|
||||
|
||||
|
||||
def gen_file_name(img_name):
|
||||
full_name = img_name.split('/')[-1]
|
||||
return os.path.splitext(full_name)
|
||||
|
||||
|
||||
def cre_groundtruth_dict(gtfile_path):
|
||||
"""
|
||||
:param gtfile_path: path to the files that contain the image name and label number
:return: dictionary with the image name as key and the label number as value
|
||||
"""
|
||||
img_gt_dict = {}
|
||||
for gtfile in os.listdir(gtfile_path):
|
||||
if gtfile != LABEL_FILE:
|
||||
with open(os.path.join(gtfile_path, gtfile), 'r') as f:
|
||||
gt = json.load(f)
|
||||
ret = gt["image"]["annotations"][0]["category_id"]
|
||||
img_gt_dict[gen_file_name(gtfile)] = ret
|
||||
return img_gt_dict
|
||||
|
||||
|
||||
def cre_groundtruth_dict_fromtxt(gtfile_path):
|
||||
"""
|
||||
:param gtfile_path: file that contains the image name and label number
:return: dictionary with the image name as key and the label number as value
|
||||
"""
|
||||
img_gt_dict = {}
|
||||
with open(gtfile_path, 'r')as f:
|
||||
for line in f.readlines():
|
||||
temp = line.strip().split(" ")
|
||||
img_name = temp[0].split(".")[0]
|
||||
img_lab = temp[1]
|
||||
img_gt_dict[img_name] = img_lab
|
||||
return img_gt_dict
|
||||
|
||||
|
||||
def load_statistical_predict_result(filepath):
|
||||
"""
|
||||
function:
|
||||
the prediction esult file data extraction
|
||||
input:
|
||||
result file:filepath
|
||||
output:
|
||||
n_label:numble of label
|
||||
data_vec: the probabilitie of prediction in the 1000
|
||||
:return: probabilities, numble of label, in_type, color
|
||||
"""
|
||||
with open(filepath, 'r')as f:
|
||||
data = f.readline()
|
||||
temp = data.strip().split(" ")
|
||||
n_label = len(temp)
|
||||
data_vec = np.zeros((n_label), dtype=np.float32)
|
||||
in_type = ''
|
||||
color = ''
|
||||
if n_label == 0:
|
||||
in_type = f.readline()
|
||||
color = f.readline()
|
||||
else:
|
||||
for ind, cls_ind in enumerate(temp):
|
||||
data_vec[ind] = np.int32(cls_ind)
|
||||
return data_vec, n_label, in_type, color
|
||||
|
||||
|
||||
def create_visualization_statistical_result(prediction_file_path,
|
||||
result_store_path, file_name,
|
||||
img_gt_dict, topn=5):
|
||||
"""
|
||||
:param prediction_file_path:
|
||||
:param result_store_path:
|
||||
:param file_name:
|
||||
:param img_gt_dict:
|
||||
:param topn:
|
||||
:return:
|
||||
"""
|
||||
writer = open(os.path.join(result_store_path, file_name), 'w')
|
||||
table_dict = {"title": "Overall statistical evaluation", "value": []}
|
||||
|
||||
count = 0
|
||||
res_cnt = 0
|
||||
n_labels = ""
|
||||
count_hit = np.zeros(topn)
|
||||
for tfile_name in os.listdir(prediction_file_path):
|
||||
count += 1
|
||||
temp = tfile_name.split('.')[0]
|
||||
index = temp.rfind('_')
|
||||
img_name = temp[:index]
|
||||
filepath = os.path.join(prediction_file_path, tfile_name)
|
||||
|
||||
ret = load_statistical_predict_result(filepath)
|
||||
prediction = ret[0]
|
||||
n_labels = ret[1]
|
||||
|
||||
gt = img_gt_dict[img_name]
|
||||
if n_labels == 1000:
|
||||
real_label = int(gt)
|
||||
elif n_labels == 1001:
|
||||
real_label = int(gt) + 1
|
||||
else:
|
||||
real_label = int(gt)
|
||||
|
||||
res_cnt = min(len(prediction), topn)
|
||||
for i in range(res_cnt):
|
||||
if str(real_label) == str(int(prediction[i])):
|
||||
count_hit[i] += 1
|
||||
break
|
||||
if 'value' not in table_dict.keys():
|
||||
print("the item value does not exist!")
|
||||
else:
|
||||
table_dict["value"].extend(
|
||||
[{"key": "Number of images", "value": str(count)},
|
||||
{"key": "Number of classes", "value": str(n_labels)}])
|
||||
if count == 0:
|
||||
accuracy = 0
|
||||
else:
|
||||
accuracy = np.cumsum(count_hit) / count
|
||||
for i in range(res_cnt):
|
||||
table_dict["value"].append({"key": "Top" + str(i + 1) + " accuracy",
|
||||
"value": str(
|
||||
round(accuracy[i] * 100, 2)) + '%'})
|
||||
json.dump(table_dict, writer)
|
||||
writer.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
try:
|
||||
# txt file path
|
||||
folder_davinci_target = sys.argv[1]
|
||||
# annotation files path, "val_label.txt"
|
||||
annotation_file_path = sys.argv[2]
|
||||
# the folder where the result json file is stored
|
||||
result_json_path = sys.argv[3]
|
||||
# result json file name
|
||||
json_file_name = sys.argv[4]
|
||||
except IndexError:
|
||||
print("Please enter target file result folder | ground truth label file"
|
||||
"| result json file folder | "
|
||||
"result json file name, such as "
|
||||
"./result val_label.txt . result.json")
|
||||
exit(1)
|
||||
|
||||
if not os.path.exists(folder_davinci_target):
|
||||
print("Target file folder does not exist.")
|
||||
|
||||
if not os.path.exists(annotation_file_path):
|
||||
print("Ground truth file does not exist.")
|
||||
|
||||
if not os.path.exists(result_json_path):
|
||||
print("Result folder doesn't exist.")
|
||||
|
||||
img_label_dict = cre_groundtruth_dict_fromtxt(annotation_file_path)
|
||||
create_visualization_statistical_result(folder_davinci_target,
|
||||
result_json_path, json_file_name,
|
||||
img_label_dict, topn=5)
|
|
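A hedged invocation sketch for the accuracy script above, following its own usage hint; the script file name and the paths are placeholders.

# result folder produced by inference, ground-truth label file, output folder, output json name.
python3.7 classification_task_metric.py ./result ./val_label.txt ./ result.json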
@ -0,0 +1 @@
{"title": "Overall statistical evaluation", "value": [{"key": "Number of images", "value": "1000"}, {"key": "Number of classes", "value": "5"}, {"key": "Top1 accuracy", "value": "69.3%"}, {"key": "Top2 accuracy", "value": "80.1%"}, {"key": "Top3 accuracy", "value": "84.2%"}, {"key": "Top4 accuracy", "value": "86.2%"}, {"key": "Top5 accuracy", "value": "87.8%"}]}
@ -0,0 +1,64 @@
|
|||
{
|
||||
"im_resnet18": {
|
||||
"stream_config": {
|
||||
"deviceId": "0"
|
||||
},
|
||||
"appsrc1": {
|
||||
"props": {
|
||||
"blocksize": "409600"
|
||||
},
|
||||
"factory": "appsrc",
|
||||
"next": "mxpi_imagedecoder0"
|
||||
},
|
||||
"mxpi_imagedecoder0": {
|
||||
"props": {
|
||||
"handleMethod": "opencv"
|
||||
},
|
||||
"factory": "mxpi_imagedecoder",
|
||||
"next": "mxpi_imageresize0"
|
||||
},
|
||||
"mxpi_imageresize0": {
|
||||
"props": {
|
||||
"handleMethod": "opencv",
|
||||
"resizeType": "Resizer_Stretch",
|
||||
"resizeHeight": "304",
|
||||
"resizeWidth": "304"
|
||||
},
|
||||
"factory": "mxpi_imageresize",
|
||||
"next": "mxpi_tensorinfer0"
|
||||
},
|
||||
"mxpi_tensorinfer0": {
|
||||
"props": {
|
||||
"dataSource": "mxpi_imageresize0",
|
||||
"modelPath": "../data/model/resnet18-304_304.om",
|
||||
"waitingTime": "2000",
|
||||
"outputDeviceId": "-1"
|
||||
},
|
||||
"factory": "mxpi_tensorinfer",
|
||||
"next": "mxpi_classpostprocessor0"
|
||||
},
|
||||
"mxpi_classpostprocessor0": {
|
||||
"props": {
|
||||
"dataSource": "mxpi_tensorinfer0",
|
||||
"postProcessConfigPath": "../data/config/resnet18.cfg",
|
||||
"labelPath": "../data/config/imagenet1000_clsidx_to_labels.names",
|
||||
"postProcessLibPath": "../../../lib/modelpostprocessors/libresnet50postprocess.so"
|
||||
},
|
||||
"factory": "mxpi_classpostprocessor",
|
||||
"next": "mxpi_dataserialize0"
|
||||
},
|
||||
"mxpi_dataserialize0": {
|
||||
"props": {
|
||||
"outputDataKeys": "mxpi_classpostprocessor0"
|
||||
},
|
||||
"factory": "mxpi_dataserialize",
|
||||
"next": "appsink0"
|
||||
},
|
||||
"appsink0": {
|
||||
"props": {
|
||||
"blocksize": "4096000"
|
||||
},
|
||||
"factory": "appsink"
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
image_path=$1
|
||||
result_dir=$2
|
||||
|
||||
set -e
|
||||
|
||||
# Simple log helper functions
|
||||
info() { echo -e "\033[1;34m[INFO ][MxStream] $1\033[1;37m" ; }
|
||||
warn() { echo >&2 -e "\033[1;31m[WARN ][MxStream] $1\033[1;37m" ; }
|
||||
|
||||
export GST_PLUGIN_SCANNER=${MX_SDK_HOME}/opensource/libexec/gstreamer-1.0/gst-plugin-scanner
|
||||
export GST_PLUGIN_PATH=${MX_SDK_HOME}/opensource/lib/gstreamer-1.0:${MX_SDK_HOME}/lib/plugins
|
||||
|
||||
#to set PYTHONPATH, import the StreamManagerApi.py
|
||||
export PYTHONPATH=$PYTHONPATH:${MX_SDK_HOME}/python
|
||||
|
||||
python3.7 main.py $image_path $result_dir
|
||||
exit 0
|
|
@ -0,0 +1,440 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train resnet."""
|
||||
import os
|
||||
import argparse
|
||||
import ast
|
||||
import moxing as mox
|
||||
import numpy as np
|
||||
|
||||
from mindspore import context
|
||||
from mindspore import export
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn.optim import Momentum, thor
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.train_thor import ConvertModelUtils
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.parallel import set_algo_parameters
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.initializer as weight_init
|
||||
import mindspore.log as logger
|
||||
from src.lr_generator import get_lr, warmup_cosine_annealing_lr, get_resnet34_lr
|
||||
from src.CrossEntropySmooth import CrossEntropySmooth
|
||||
from src.config import cfg
|
||||
from src.eval_callback import EvalCallBack
|
||||
from src.metric import DistAccuracy, ClassifyCorrectCell
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='Image classification')
|
||||
parser.add_argument('--train_url', type=str, default='',
|
||||
help='the path model saved')
|
||||
parser.add_argument('--data_url', type=str, default='',
|
||||
help='the training data')
|
||||
|
||||
parser.add_argument('--net', type=str, default="resnet18",
|
||||
help='Resnet Model, resnet18, resnet34, '
|
||||
'resnet50 or resnet101')
|
||||
parser.add_argument('--dataset', type=str, default="imagenet2012",
|
||||
help='Dataset, either cifar10 or imagenet2012')
|
||||
parser.add_argument('--run_distribute', type=ast.literal_eval, default=False,
|
||||
help='Run distribute')
|
||||
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
|
||||
|
||||
parser.add_argument('--dataset_path', type=str, default="/cache",
|
||||
help='Dataset path')
|
||||
parser.add_argument('--device_target', type=str, default='Ascend',
|
||||
choices=("Ascend", "GPU", "CPU"),
|
||||
help="Device target, support Ascend, GPU and CPU.")
|
||||
parser.add_argument('--pre_trained', type=str, default=None,
|
||||
help='Pretrained checkpoint path')
|
||||
parser.add_argument('--parameter_server', type=ast.literal_eval, default=False,
|
||||
help='Run parameter server train')
|
||||
parser.add_argument("--filter_weight", type=ast.literal_eval, default=False,
|
||||
help="Filter head weight parameters, default is False.")
|
||||
parser.add_argument("--run_eval", type=ast.literal_eval, default=False,
|
||||
help="Run evaluation when training, default is False.")
|
||||
parser.add_argument('--eval_dataset_path', type=str, default=None,
|
||||
help='Evaluation dataset path when run_eval is True')
|
||||
parser.add_argument("--save_best_ckpt", type=ast.literal_eval, default=True,
|
||||
help="Save best checkpoint when run_eval is True, "
|
||||
"default is True.")
|
||||
parser.add_argument("--eval_start_epoch", type=int, default=40,
|
||||
help="Evaluation start epoch when run_eval is True, "
|
||||
"default is 40.")
|
||||
parser.add_argument("--eval_interval", type=int, default=1,
|
||||
help="Evaluation interval when run_eval is True, "
|
||||
"default is 1.")
|
||||
parser.add_argument('--enable_cache', type=ast.literal_eval, default=False,
|
||||
help='Caching the eval dataset in memory to speedup '
|
||||
'evaluation, default is False.')
|
||||
parser.add_argument('--cache_session_id', type=str, default="",
|
||||
help='The session id for cache service.')
|
||||
parser.add_argument('--mode', type=str, default='GRAPH',
|
||||
choices=('GRAPH', 'PYNATIVE'),
|
||||
help="Graph mode or PyNative mode, default is Graph mode")
|
||||
|
||||
parser.add_argument("--epoch_size", type=int, default=1,
|
||||
help="training epoch size, default is 1.")
|
||||
parser.add_argument("--num_classes", type=int, default=1001,
|
||||
help="number of dataset categories, default is 1001.")
|
||||
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
CKPT_OUTPUT_PATH = "./"
|
||||
|
||||
set_seed(1)
|
||||
|
||||
if args_opt.net in ("resnet18", "resnet50"):
|
||||
if args_opt.net == "resnet18":
|
||||
from src.resnet import resnet18 as resnet
|
||||
if args_opt.net == "resnet50":
|
||||
from src.resnet import resnet50 as resnet
|
||||
if args_opt.dataset == "cifar10":
|
||||
from src.config import config1 as config
|
||||
from src.dataset import create_dataset1 as create_dataset
|
||||
else:
|
||||
from src.config import config2 as config
|
||||
if args_opt.mode == "GRAPH":
|
||||
from src.dataset import create_dataset2 as create_dataset
|
||||
else:
|
||||
from src.dataset import create_dataset_pynative as create_dataset
|
||||
elif args_opt.net == "resnet34":
|
||||
from src.resnet import resnet34 as resnet
|
||||
from src.config import config_resnet34 as config
|
||||
from src.dataset import create_dataset_resnet34 as create_dataset
|
||||
elif args_opt.net == "resnet101":
|
||||
from src.resnet import resnet101 as resnet
|
||||
from src.config import config3 as config
|
||||
from src.dataset import create_dataset3 as create_dataset
|
||||
else:
|
||||
from src.resnet import se_resnet50 as resnet
|
||||
from src.config import config4 as config
|
||||
from src.dataset import create_dataset4 as create_dataset
|
||||
|
||||
if cfg.optimizer == "Thor":
|
||||
if args_opt.device_target == "Ascend":
|
||||
from src.config import config_thor_Ascend as config
|
||||
else:
|
||||
from src.config import config_thor_gpu as config
|
||||
|
||||
|
||||
def filter_checkpoint_parameter_by_list(origin_dict, param_filter):
|
||||
"""remove useless parameters according to filter_list"""
|
||||
for key in list(origin_dict.keys()):
|
||||
for name in param_filter:
|
||||
if name in key:
|
||||
print("Delete parameter from checkpoint: ", key)
|
||||
del origin_dict[key]
|
||||
break
|
||||
|
||||
|
||||
def apply_eval(eval_param):
|
||||
eval_model = eval_param["model"]
|
||||
eval_ds = eval_param["dataset"]
|
||||
metrics_name = eval_param["metrics_name"]
|
||||
res = eval_model.eval(eval_ds)
|
||||
return res[metrics_name]
|
||||
|
||||
|
||||
def set_graph_kernel_context(run_platform, net_name):
|
||||
if run_platform == "GPU" and net_name == "resnet101":
|
||||
context.set_context(enable_graph_kernel=True,
|
||||
graph_kernel_flags="--enable_parallel_fusion")
|
||||
|
||||
|
||||
def _get_last_ckpt(ckpt_dir):
|
||||
ckpt_files = [ckpt_file for ckpt_file in os.listdir(ckpt_dir)
|
||||
if ckpt_file.endswith('.ckpt')]
|
||||
if not ckpt_files:
|
||||
print("No ckpt file found.")
|
||||
return None
|
||||
|
||||
return os.path.join(ckpt_dir, sorted(ckpt_files)[-1])
|
||||
|
||||
|
||||
def _export_air(ckpt_dir):
|
||||
ckpt_file = _get_last_ckpt(ckpt_dir)
|
||||
if not ckpt_file:
|
||||
return
|
||||
net = resnet(config.class_num)
|
||||
param_dict = load_checkpoint(ckpt_file)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
input_arr = Tensor(np.zeros([1, 3, 304, 304],
|
||||
np.float32))
|
||||
export(net, input_arr, file_name="resnet",
|
||||
file_format="AIR")
|
||||
|
||||
|
||||
def set_config():
|
||||
config.epoch_size = args_opt.epoch_size
|
||||
config.num_classes = args_opt.num_classes
|
||||
|
||||
|
||||
def init_context(target):
|
||||
if args_opt.mode == 'GRAPH':
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=target,
|
||||
save_graphs=False)
|
||||
set_graph_kernel_context(target, args_opt.net)
|
||||
else:
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target=target,
|
||||
save_graphs=False)
|
||||
if args_opt.parameter_server:
|
||||
context.set_ps_context(enable_ps=True)
|
||||
if args_opt.run_distribute:
|
||||
if target == "Ascend":
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(device_id=device_id,
|
||||
enable_auto_mixed_precision=True)
|
||||
context.set_auto_parallel_context(
|
||||
device_num=args_opt.device_num,
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
set_algo_parameters(elementwise_op_strategy_follow=True)
|
||||
if args_opt.net == "resnet50" or args_opt.net == "se-resnet50":
|
||||
context.set_auto_parallel_context(
|
||||
all_reduce_fusion_config=[85, 160])
|
||||
elif args_opt.net == "resnet101":
|
||||
context.set_auto_parallel_context(
|
||||
all_reduce_fusion_config=[80, 210, 313])
|
||||
init()
|
||||
# GPU target
|
||||
else:
|
||||
init()
|
||||
context.set_auto_parallel_context(
|
||||
device_num=get_group_size(),
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
if args_opt.net == "resnet50":
|
||||
context.set_auto_parallel_context(
|
||||
all_reduce_fusion_config=[85, 160])
|
||||
|
||||
|
||||
def init_weight(net):
|
||||
if os.path.exists(args_opt.pre_trained):
|
||||
param_dict = load_checkpoint(args_opt.pre_trained)
|
||||
if args_opt.filter_weight:
|
||||
filter_list = [x.name for x in net.end_point.get_parameters()]
|
||||
filter_checkpoint_parameter_by_list(param_dict, filter_list)
|
||||
load_param_into_net(net, param_dict)
|
||||
else:
|
||||
for _, cell in net.cells_and_names():
|
||||
if isinstance(cell, nn.Conv2d):
|
||||
cell.weight.set_data(
|
||||
weight_init.initializer(weight_init.XavierUniform(),
|
||||
cell.weight.shape,
|
||||
cell.weight.dtype))
|
||||
if isinstance(cell, nn.Dense):
|
||||
cell.weight.set_data(
|
||||
weight_init.initializer(weight_init.TruncatedNormal(),
|
||||
cell.weight.shape,
|
||||
cell.weight.dtype))
|
||||
|
||||
|
||||
def init_lr(step_size):
|
||||
if cfg.optimizer == "Thor":
|
||||
from src.lr_generator import get_thor_lr
|
||||
lr = get_thor_lr(0, config.lr_init, config.lr_decay,
|
||||
config.lr_end_epoch, step_size, decay_epochs=39)
|
||||
else:
|
||||
if args_opt.net in ("resnet18", "resnet34", "resnet50", "se-resnet50"):
|
||||
lr = get_lr(lr_init=config.lr_init, lr_end=config.lr_end,
|
||||
lr_max=config.lr_max,
|
||||
warmup_epochs=config.warmup_epochs,
|
||||
total_epochs=config.epoch_size,
|
||||
steps_per_epoch=step_size,
|
||||
lr_decay_mode=config.lr_decay_mode)
|
||||
else:
|
||||
lr = warmup_cosine_annealing_lr(
|
||||
config.lr, step_size, config.warmup_epochs, config.epoch_size,
|
||||
config.pretrain_epoch_size * step_size)
|
||||
if args_opt.net == "resnet34":
|
||||
lr = get_resnet34_lr(lr_init=config.lr_init,
|
||||
lr_end=config.lr_end,
|
||||
lr_max=config.lr_max,
|
||||
warmup_epochs=config.warmup_epochs,
|
||||
total_epochs=config.epoch_size,
|
||||
steps_per_epoch=step_size)
|
||||
return Tensor(lr)
|
||||
|
||||
|
||||
def define_opt(net, lr):
|
||||
decayed_params = []
|
||||
no_decayed_params = []
|
||||
for param in net.trainable_params():
|
||||
if 'beta' not in param.name and 'gamma' not in param.name and 'bias' \
|
||||
not in param.name:
|
||||
decayed_params.append(param)
|
||||
else:
|
||||
no_decayed_params.append(param)
|
||||
|
||||
group_params = [
|
||||
{'params': decayed_params, 'weight_decay': config.weight_decay},
|
||||
{'params': no_decayed_params},
|
||||
{'order_params': net.trainable_params()}]
|
||||
opt = Momentum(group_params, lr, config.momentum,
|
||||
loss_scale=config.loss_scale)
|
||||
return opt
|
||||
|
||||
|
||||
def define_model(net, opt, target):
|
||||
if args_opt.dataset == "imagenet2012":
|
||||
if not config.use_label_smooth:
|
||||
config.label_smooth_factor = 0.0
|
||||
loss = CrossEntropySmooth(sparse=True, reduction="mean",
|
||||
smooth_factor=config.label_smooth_factor,
|
||||
num_classes=config.class_num)
|
||||
else:
|
||||
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||
loss_scale = FixedLossScaleManager(config.loss_scale,
|
||||
drop_overflow_update=False)
|
||||
dist_eval_network = ClassifyCorrectCell(
|
||||
net) if args_opt.run_distribute else None
|
||||
metrics = {"acc"}
|
||||
if args_opt.run_distribute:
|
||||
metrics = {'acc': DistAccuracy(batch_size=config.batch_size,
|
||||
device_num=args_opt.device_num)}
|
||||
if (args_opt.net not in ("resnet18", "resnet50", "resnet101",
|
||||
"se-resnet50")) or args_opt.parameter_server \
|
||||
or target == "CPU":
|
||||
# fp32 training
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, metrics=metrics,
|
||||
eval_network=dist_eval_network)
|
||||
else:
|
||||
model = Model(net, loss_fn=loss, optimizer=opt,
|
||||
loss_scale_manager=loss_scale, metrics=metrics,
|
||||
amp_level="O2", keep_batchnorm_fp32=False,
|
||||
eval_network=dist_eval_network)
|
||||
return model, loss, loss_scale
|
||||
|
||||
|
||||
def run_eval(model, target, ckpt_save_dir):
|
||||
if args_opt.eval_dataset_path is None \
|
||||
or (not os.path.isdir(args_opt.eval_dataset_path)):
|
||||
raise ValueError(
|
||||
"{} is not a existing path.".format(args_opt.eval_dataset_path))
|
||||
eval_dataset = create_dataset(
|
||||
dataset_path=args_opt.eval_dataset_path,
|
||||
do_train=False,
|
||||
batch_size=config.batch_size,
|
||||
target=target,
|
||||
enable_cache=args_opt.enable_cache,
|
||||
cache_session_id=args_opt.cache_session_id)
|
||||
eval_param_dict = {"model": model, "dataset": eval_dataset,
|
||||
"metrics_name": "acc"}
|
||||
eval_cb = EvalCallBack(apply_eval, eval_param_dict,
|
||||
interval=args_opt.eval_interval,
|
||||
eval_start_epoch=args_opt.eval_start_epoch,
|
||||
save_best_ckpt=args_opt.save_best_ckpt,
|
||||
ckpt_directory=ckpt_save_dir,
|
||||
besk_ckpt_name="best_acc.ckpt",
|
||||
metrics_name="acc")
|
||||
return eval_cb
|
||||
|
||||
|
||||
def main():
|
||||
set_config()
|
||||
target = args_opt.device_target
|
||||
if target == "CPU":
|
||||
args_opt.run_distribute = False
|
||||
|
||||
# init context
|
||||
init_context(target)
|
||||
ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(
|
||||
get_rank()) + "/"
|
||||
|
||||
# create dataset
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True,
|
||||
repeat_num=1,
|
||||
batch_size=config.batch_size, target=target,
|
||||
distribute=args_opt.run_distribute)
|
||||
step_size = dataset.get_dataset_size()
|
||||
|
||||
# define net
|
||||
net = resnet(class_num=config.class_num)
|
||||
if args_opt.parameter_server:
|
||||
net.set_param_ps()
|
||||
|
||||
# init weight
|
||||
init_weight(net)
|
||||
|
||||
# init lr
|
||||
lr = init_lr(step_size)
|
||||
|
||||
# define opt
|
||||
opt = define_opt(net, lr)
|
||||
|
||||
# define model
|
||||
model, loss, loss_scale = define_model(net, opt, target)
|
||||
|
||||
if cfg.optimizer == "Thor" and args_opt.dataset == "imagenet2012":
|
||||
from src.lr_generator import get_thor_damping
|
||||
damping = get_thor_damping(0, config.damping_init, config.damping_decay,
|
||||
70, step_size)
|
||||
split_indices = [26, 53]
|
||||
opt = thor(net, lr, Tensor(damping), config.momentum,
|
||||
config.weight_decay, config.loss_scale,
|
||||
config.batch_size, split_indices=split_indices,
|
||||
frequency=config.frequency)
|
||||
model = ConvertModelUtils().convert_to_thor_model(
|
||||
model=model, network=net, loss_fn=loss, optimizer=opt,
|
||||
loss_scale_manager=loss_scale, metrics={'acc'}, amp_level="O2",
|
||||
keep_batchnorm_fp32=False)
|
||||
args_opt.run_eval = False
|
||||
logger.warning("Thor optimizer not support evaluation while training.")
|
||||
|
||||
# define callbacks
|
||||
time_cb = TimeMonitor(data_size=step_size)
|
||||
loss_cb = LossMonitor()
|
||||
cb = [time_cb, loss_cb]
|
||||
if config.save_checkpoint:
|
||||
config_ck = CheckpointConfig(
|
||||
save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir,
|
||||
config=config_ck)
|
||||
cb += [ckpt_cb]
|
||||
if args_opt.run_eval:
|
||||
eval_cb = run_eval(model, target, ckpt_save_dir)
|
||||
cb += [eval_cb]
|
||||
# train model
|
||||
if args_opt.net == "se-resnet50":
|
||||
config.epoch_size = config.train_epoch_size
|
||||
dataset_sink_mode = (not args_opt.parameter_server) and target != "CPU"
|
||||
model.train(config.epoch_size - config.pretrain_epoch_size, dataset,
|
||||
callbacks=cb,
|
||||
sink_size=dataset.get_dataset_size(),
|
||||
dataset_sink_mode=dataset_sink_mode)
|
||||
|
||||
if args_opt.run_eval and args_opt.enable_cache:
|
||||
print(
|
||||
"Remember to shut down the cache server via \"cache_admin --stop\"")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Copy the dataset to the cache directory that ModelArts reads from
|
||||
mox.file.copy_parallel(args_opt.data_url, '/cache')
|
||||
main()
|
||||
# After training, copy the generated model to the specified output directory
|
||||
if not os.path.exists(CKPT_OUTPUT_PATH):
|
||||
os.makedirs(CKPT_OUTPUT_PATH, exist_ok=True)
|
||||
_export_air(CKPT_OUTPUT_PATH)
|
||||
mox.file.copy_parallel(CKPT_OUTPUT_PATH, args_opt.train_url)
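A hedged launch sketch for the training entry above: the entry file name, OBS paths, and hyper-parameter values are placeholders; only the flag names come from the argparse definitions, and on ModelArts the platform itself injects data_url and train_url.

# Placeholder entry name, OBS paths and values; only the flag names are taken from the argparse setup above.
python3.7 modelarts_train.py --data_url obs://bucket/imagenet/train --train_url obs://bucket/output \
--net resnet18 --dataset imagenet2012 --epoch_size 90 --num_classes 1001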
|
Loading…
Reference in New Issue