!22979 ResNext18模型增加ModelArts训练、SDK和MxBase推理功能

Merge pull request !22979 from Atlas_ymc/master
This commit is contained in:
i-robot 2021-09-14 06:22:04 +00:00 committed by Gitee
commit c7a60b2e4c
16 changed files with 1386 additions and 0 deletions

View File

@ -18,6 +18,7 @@
# Modelzoo
"mindspore/model_zoo/official/cv/yolov4_tiny/infer/mxbase/src/Yolov4TinyDetection.h" "runtime/references"
"mindspore/model_zoo/official/cv/yolov4_tiny/infer/mxbase/src/PostProcess/Yolov4TinyMindsporePost.h" "runtime/references"
"mindspore/model_zoo/official/cv/resnet/infer/ResNet18/mxbase/Resnet18ClassifyOpencv.h" "runtime/references"
# MindData
"mindspore/mindspore/ccsrc/minddata/mindrecord/include/shard_page.h" "runtime/string"

View File

@ -0,0 +1,16 @@
aipp_op {
aipp_mode: static
input_format : RGB888_U8
rbuv_swap_switch : true
mean_chn_0 : 0
mean_chn_1 : 0
mean_chn_2 : 0
min_chn_0 : 123.675
min_chn_1 : 116.28
min_chn_2 : 103.53
var_reci_chn_0 : 0.0171247538316637
var_reci_chn_1 : 0.0175070028011204
var_reci_chn_2 : 0.0174291938997821
}

View File

@ -0,0 +1,30 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
air_path=$1
aipp_cfg_path=$2
om_path=$3
/usr/local/Ascend/atc/bin/atc \
--model="$air_path" \
--framework=1 \
--output="$om_path" \
--input_format=NCHW --input_shape="actual_input_1:1,3,304,304" \
--enable_small_channel=1 \
--log=error \
--soc_version=Ascend310 \
--insert_op_conf="$aipp_cfg_path" \
--output_type=FP32

View File

@ -0,0 +1,3 @@
CLASS_NUM=1001
SOFTMAX=false
TOP_K=5

View File

@ -0,0 +1,48 @@
#!/usr/bin/env bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
docker_image=$1
data_path=$2
function show_help() {
echo "Usage: docker_start.sh docker_image data_path"
}
function param_check() {
if [ -z "${docker_image}" ]; then
echo "please input docker_image"
show_help
exit 1
fi
if [ -z "${data_path}" ]; then
echo "please input data_path"
show_help
exit 1
fi
}
param_check
docker run -it \
--device=/dev/davinci0 \
--device=/dev/davinci_manager \
--device=/dev/devmm_svm \
--device=/dev/hisi_hdc \
-v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
-v ${data_path}:${data_path} \
${docker_image} \
/bin/bash

View File

@ -0,0 +1,48 @@
cmake_minimum_required(VERSION 3.14.0)
project(resnet)
set(TARGET resnet)
add_definitions(-DENABLE_DVPP_INTERFACE)
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
add_definitions(-Dgoogle=mindxsdk_private)
add_compile_options(-std=c++11 -fPIE -fstack-protector-all -fPIC -Wall)
add_link_options(-Wl,-z,relro,-z,now,-z,noexecstack -s -pie)
# Check environment variable
if(NOT DEFINED ENV{ASCEND_HOME})
message(FATAL_ERROR "please define environment variable:ASCEND_HOME")
endif()
if(NOT DEFINED ENV{ASCEND_VERSION})
message(WARNING "please define environment variable:ASCEND_VERSION")
endif()
if(NOT DEFINED ENV{ARCH_PATTERN})
message(WARNING "please define environment variable:ARCH_PATTERN")
endif()
set(ACL_INC_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/include)
set(ACL_LIB_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/lib64)
set(MXBASE_ROOT_DIR $ENV{MX_SDK_HOME})
set(MXBASE_INC ${MXBASE_ROOT_DIR}/include)
set(MXBASE_LIB_DIR ${MXBASE_ROOT_DIR}/lib)
set(MXBASE_POST_LIB_DIR ${MXBASE_ROOT_DIR}/lib/modelpostprocessors)
set(MXBASE_POST_PROCESS_DIR ${MXBASE_ROOT_DIR}/include/MxBase/postprocess/include)
set(OPENSOURCE_DIR ${MXBASE_ROOT_DIR}/opensource)
include_directories(${ACL_INC_DIR})
include_directories(${OPENSOURCE_DIR}/include)
include_directories(${OPENSOURCE_DIR}/include/opencv4)
include_directories(${MXBASE_INC})
include_directories(${MXBASE_POST_PROCESS_DIR})
link_directories(${ACL_LIB_DIR})
link_directories(${OPENSOURCE_DIR}/lib)
link_directories(${MXBASE_LIB_DIR})
link_directories(${MXBASE_POST_LIB_DIR})
add_executable(${TARGET} main.cpp Resnet18ClassifyOpencv.cpp)
target_link_libraries(${TARGET} glog cpprest mxbase resnet50postprocess opencv_world stdc++fs)
install(TARGETS ${TARGET} RUNTIME DESTINATION ${PROJECT_SOURCE_DIR}/)

View File

@ -0,0 +1,217 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "Resnet18ClassifyOpencv.h"
#include "MxBase/DeviceManager/DeviceManager.h"
#include "MxBase/Log/Log.h"
using MxBase::DeviceManager;
using MxBase::TensorBase;
using MxBase::MemoryData;
using MxBase::ClassInfo;
namespace {
const uint32_t YUV_BYTE_NU = 3;
const uint32_t YUV_BYTE_DE = 2;
const uint32_t VPC_H_ALIGN = 2;
}
APP_ERROR Resnet18ClassifyOpencv::Init(const InitParam &initParam) {
deviceId_ = initParam.deviceId;
APP_ERROR ret = DeviceManager::GetInstance()->InitDevices();
if (ret != APP_ERR_OK) {
LogError << "Init devices failed, ret=" << ret << ".";
return ret;
}
ret = MxBase::TensorContext::GetInstance()->SetContext(initParam.deviceId);
if (ret != APP_ERR_OK) {
LogError << "Set context failed, ret=" << ret << ".";
return ret;
}
dvppWrapper_ = std::make_shared<MxBase::DvppWrapper>();
ret = dvppWrapper_->Init();
if (ret != APP_ERR_OK) {
LogError << "DvppWrapper init failed, ret=" << ret << ".";
return ret;
}
model_ = std::make_shared<MxBase::ModelInferenceProcessor>();
ret = model_->Init(initParam.modelPath, modelDesc_);
if (ret != APP_ERR_OK) {
LogError << "ModelInferenceProcessor init failed, ret=" << ret << ".";
return ret;
}
MxBase::ConfigData configData;
const std::string softmax = initParam.softmax ? "true" : "false";
const std::string checkTensor = initParam.checkTensor ? "true" : "false";
configData.SetJsonValue("CLASS_NUM", std::to_string(initParam.classNum));
configData.SetJsonValue("TOP_K", std::to_string(initParam.topk));
configData.SetJsonValue("SOFTMAX", softmax);
configData.SetJsonValue("CHECK_MODEL", checkTensor);
auto jsonStr = configData.GetCfgJson().serialize();
std::map<std::string, std::shared_ptr<void>> config;
config["postProcessConfigContent"] = std::make_shared<std::string>(jsonStr);
config["labelPath"] = std::make_shared<std::string>(initParam.labelPath);
post_ = std::make_shared<MxBase::Resnet50PostProcess>();
ret = post_->Init(config);
if (ret != APP_ERR_OK) {
LogError << "Resnet50PostProcess init failed, ret=" << ret << ".";
return ret;
}
return APP_ERR_OK;
}
APP_ERROR Resnet18ClassifyOpencv::DeInit() {
dvppWrapper_->DeInit();
model_->DeInit();
post_->DeInit();
DeviceManager::GetInstance()->DestroyDevices();
return APP_ERR_OK;
}
APP_ERROR Resnet18ClassifyOpencv::ConvertImageToTensorBase(const std::string &imgPath,
TensorBase &tensorBase) {
static constexpr uint32_t resizeHeight = 304;
static constexpr uint32_t resizeWidth = 304;
cv::Mat imageMat = cv::imread(imgPath, cv::IMREAD_COLOR);
cv::resize(imageMat, imageMat, cv::Size(resizeWidth, resizeHeight));
const uint32_t dataSize = imageMat.cols * imageMat.rows * MxBase::XRGB_WIDTH_NU;
LogInfo << "image size after resize" << imageMat.cols << " " << imageMat.rows;
MemoryData memoryDataDst(dataSize, MemoryData::MEMORY_DEVICE, deviceId_);
MemoryData memoryDataSrc(imageMat.data, dataSize, MemoryData::MEMORY_HOST_MALLOC);
APP_ERROR ret = MxBase::MemoryHelper::MxbsMallocAndCopy(memoryDataDst, memoryDataSrc);
if (ret != APP_ERR_OK) {
LogError << GetError(ret) << "Memory malloc failed.";
return ret;
}
std::vector<uint32_t> shape = {imageMat.rows * MxBase::XRGB_WIDTH_NU, static_cast<uint32_t>(imageMat.cols)};
tensorBase = TensorBase(memoryDataDst, false, shape, MxBase::TENSOR_DTYPE_UINT8);
return APP_ERR_OK;
}
APP_ERROR Resnet18ClassifyOpencv::Inference(std::vector<TensorBase> &inputs,
std::vector<TensorBase> &outputs) {
auto dtypes = model_->GetOutputDataType();
for (size_t i = 0; i < modelDesc_.outputTensors.size(); ++i) {
std::vector<uint32_t> shape = {};
for (size_t j = 0; j < modelDesc_.outputTensors[i].tensorDims.size(); ++j) {
shape.push_back((uint32_t)modelDesc_.outputTensors[i].tensorDims[j]);
}
TensorBase tensor(shape, dtypes[i], MemoryData::MemoryType::MEMORY_DEVICE, deviceId_);
APP_ERROR ret = TensorBase::TensorBaseMalloc(tensor);
if (ret != APP_ERR_OK) {
LogError << "TensorBaseMalloc failed, ret=" << ret << ".";
return ret;
}
outputs.push_back(tensor);
}
MxBase::DynamicInfo dynamicInfo = {};
dynamicInfo.dynamicType = MxBase::DynamicType::STATIC_BATCH;
auto startTime = std::chrono::high_resolution_clock::now();
APP_ERROR ret = model_->ModelInference(inputs, outputs, dynamicInfo);
auto endTime = std::chrono::high_resolution_clock::now();
double costMs = std::chrono::duration<double, std::milli>(endTime - startTime).count();
inferCostTimeMilliSec += costMs;
if (ret != APP_ERR_OK) {
LogError << "ModelInference failed, ret=" << ret << ".";
return ret;
}
return APP_ERR_OK;
}
APP_ERROR Resnet18ClassifyOpencv::PostProcess(std::vector<TensorBase> &inputs,
std::vector<std::vector<ClassInfo>> &clsInfos) {
APP_ERROR ret = post_->Process(inputs, clsInfos);
if (ret != APP_ERR_OK) {
LogError << "Process failed, ret=" << ret << ".";
return ret;
}
return APP_ERR_OK;
}
APP_ERROR Resnet18ClassifyOpencv::SaveResult(const std::string &imgPath,
std::vector<std::vector<ClassInfo>> &batchClsInfos) {
LogInfo << "image path" << imgPath;
std::string fileName = imgPath.substr(imgPath.find_last_of("/") + 1);
size_t dot = fileName.find_last_of(".");
std::string resFileName = "result/" + fileName.substr(0, dot) + "_1.txt";
LogInfo << "file path for saving result" << resFileName;
std::ofstream outfile(resFileName);
if (outfile.fail()) {
LogError << "Failed to open result file: ";
return APP_ERR_COMM_FAILURE;
}
uint32_t batchIndex = 0;
for (auto clsInfos : batchClsInfos) {
std::string resultStr;
for (auto clsInfo : clsInfos) {
LogDebug << " className:" << clsInfo.className << " confidence:" << clsInfo.confidence <<
" classIndex:" << clsInfo.classId;
resultStr += std::to_string(clsInfo.classId) + " ";
}
outfile << resultStr << std::endl;
batchIndex++;
}
outfile.close();
return APP_ERR_OK;
}
APP_ERROR Resnet18ClassifyOpencv::Process(const std::string &imgPath) {
TensorBase tensorBase;
std::vector<TensorBase> inputs;
std::vector<TensorBase> outputs;
APP_ERROR ret = ConvertImageToTensorBase(imgPath, tensorBase);
if (ret != APP_ERR_OK) {
LogError << "Convert image to TensorBase failed, ret=" << ret << ".";
return ret;
}
inputs.push_back(tensorBase);
auto startTime = std::chrono::high_resolution_clock::now();
ret = Inference(inputs, outputs);
auto endTime = std::chrono::high_resolution_clock::now();
double costMs = std::chrono::duration<double, std::milli>(endTime - startTime).count();
inferCostTimeMilliSec += costMs;
if (ret != APP_ERR_OK) {
LogError << "Inference failed, ret=" << ret << ".";
return ret;
}
std::vector<std::vector<ClassInfo>> BatchClsInfos;
ret = PostProcess(outputs, BatchClsInfos);
if (ret != APP_ERR_OK) {
LogError << "PostProcess failed, ret=" << ret << ".";
return ret;
}
ret = SaveResult(imgPath, BatchClsInfos);
if (ret != APP_ERR_OK) {
LogError << "Save infer results into file failed. ret = " << ret << ".";
return ret;
}
return APP_ERR_OK;
}

View File

@ -0,0 +1,65 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MXBASE_RESNET18CLASSIFYOPENCV_H
#define MXBASE_RESNET18CLASSIFYOPENCV_H
#include <map>
#include <memory>
#include <string>
#include <vector>
#include <opencv2/opencv.hpp>
#include "MxBase/DvppWrapper/DvppWrapper.h"
#include "MxBase/ModelInfer/ModelInferenceProcessor.h"
#include "ClassPostProcessors/Resnet50PostProcess.h"
#include "MxBase/Tensor/TensorContext/TensorContext.h"
struct InitParam {
uint32_t deviceId;
std::string labelPath;
uint32_t classNum;
uint32_t topk;
bool softmax;
bool checkTensor;
std::string modelPath;
};
class Resnet18ClassifyOpencv {
public:
APP_ERROR Init(const InitParam &initParam);
APP_ERROR DeInit();
APP_ERROR ConvertImageToTensorBase(const std::string &imgPath, MxBase::TensorBase &tensorBase);
APP_ERROR Inference(std::vector<MxBase::TensorBase> &inputs,
std::vector<MxBase::TensorBase> &outputs);
APP_ERROR PostProcess(std::vector<MxBase::TensorBase> &inputs,
std::vector<std::vector<MxBase::ClassInfo>> &clsInfos);
APP_ERROR Process(const std::string &imgPath);
// get infer time
double GetInferCostMilliSec() const {return inferCostTimeMilliSec;}
private:
APP_ERROR SaveResult(const std::string &imgPath,
std::vector<std::vector<MxBase::ClassInfo>> &batchClsInfos);
std::shared_ptr<MxBase::DvppWrapper> dvppWrapper_;
std::shared_ptr<MxBase::ModelInferenceProcessor> model_;
std::shared_ptr<MxBase::Resnet50PostProcess> post_;
MxBase::ModelDesc modelDesc_;
uint32_t deviceId_ = 0;
// infer time
double inferCostTimeMilliSec = 0.0;
};
#endif

View File

@ -0,0 +1,53 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
path_cur=$(dirname "$0")
function check_env()
{
# set ASCEND_VERSION to ascend-toolkit/latest when it was not specified by user
if [ ! "${ASCEND_VERSION}" ]; then
echo "Set ASCEND_VERSION to the default value: ${ASCEND_VERSION}"
else
echo "ASCEND_VERSION is set to ${ASCEND_VERSION} by user"
fi
if [ ! "${ARCH_PATTERN}" ]; then
# set ARCH_PATTERN to ./ when it was not specified by user
echo "ARCH_PATTERN is set to the default value: ${ARCH_PATTERN}"
else
echo "ARCH_PATTERN is set to ${ARCH_PATTERN} by user"
fi
}
function build_resnet18()
{
cd "$path_cur" || exit
rm -rf build
mkdir -p build
cd build || exit
cmake ..
make
ret=$?
if [ ${ret} -ne 0 ]; then
echo "Failed to build resnet18."
exit ${ret}
fi
make install
}
check_env
build_resnet18

View File

@ -0,0 +1,90 @@
/*
* Copyright (c) 2021. Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <dirent.h>
#include "Resnet18ClassifyOpencv.h"
#include "MxBase/Log/Log.h"
namespace {
const uint32_t CLASS_NUM = 1001;
}
APP_ERROR ReadFilesFromPath(const std::string &path, std::vector<std::string> *files) {
DIR *dir = NULL;
struct dirent *ptr = NULL;
if ((dir=opendir(path.c_str())) == NULL) {
LogError << "Open dir error: " << path;
return APP_ERR_COMM_OPEN_FAIL;
}
while ((ptr=readdir(dir)) != NULL) {
// d_type == 8 is file
if (ptr->d_type == 8) {
files->push_back(path + ptr->d_name);
}
}
closedir(dir);
// sort ascending order
sort(files->begin(), files->end());
return APP_ERR_OK;
}
int main(int argc, char* argv[]) {
if (argc <= 1) {
LogWarn << "Please input image path, such as './resnet image_dir'.";
return APP_ERR_OK;
}
InitParam initParam = {};
initParam.deviceId = 0;
initParam.classNum = CLASS_NUM;
initParam.labelPath = "../data/config/imagenet1000_clsidx_to_labels.names";
initParam.topk = 5;
initParam.softmax = false;
initParam.checkTensor = true;
initParam.modelPath = "../data/model/resnet18-304_304.om";
auto resnet18 = std::make_shared<Resnet18ClassifyOpencv>();
APP_ERROR ret = resnet18->Init(initParam);
if (ret != APP_ERR_OK) {
LogError << "resnet18Classify init failed, ret=" << ret << ".";
return ret;
}
std::string inferPath = argv[1];
std::vector<std::string> files;
ret = ReadFilesFromPath(inferPath, &files);
if (ret != APP_ERR_OK) {
LogError << "Read files from path failed, ret=" << ret << ".";
return ret;
}
auto startTime = std::chrono::high_resolution_clock::now();
for (uint32_t i = 0; i < files.size(); i++) {
ret = resnet18->Process(files[i]);
if (ret != APP_ERR_OK) {
LogError << "resnet18Classify process failed, ret=" << ret << ".";
resnet18->DeInit();
return ret;
}
}
auto endTime = std::chrono::high_resolution_clock::now();
resnet18->DeInit();
double costMilliSecs = std::chrono::duration<double, std::milli>(endTime - startTime).count();
double fps = 1000.0 * files.size() / resnet18->GetInferCostMilliSec();
LogInfo << "[Process Delay] cost: " << costMilliSecs << " ms\tfps: " << fps << " imgs/sec";
return APP_ERR_OK;
}

View File

@ -0,0 +1,101 @@
# coding=utf-8
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import datetime
import json
import os
import sys
from StreamManagerApi import StreamManagerApi
from StreamManagerApi import MxDataInput
def run():
# init stream manager
stream_manager_api = StreamManagerApi()
ret = stream_manager_api.InitManager()
if ret != 0:
print("Failed to init Stream manager, ret=%s" % str(ret))
return
# create streams by pipeline config file
with open("./resnet18.pipeline", 'rb') as f:
pipelineStr = f.read()
ret = stream_manager_api.CreateMultipleStreams(pipelineStr)
if ret != 0:
print("Failed to create Stream, ret=%s" % str(ret))
return
# Construct the input of the stream
data_input = MxDataInput()
dir_name = sys.argv[1]
res_dir_name = sys.argv[2]
file_list = os.listdir(dir_name)
if not os.path.exists(res_dir_name):
os.makedirs(res_dir_name)
for file_name in file_list:
file_path = os.path.join(dir_name, file_name)
if not (file_name.lower().endswith(
".jpg") or file_name.lower().endswith(".jpeg")):
continue
with open(file_path, 'rb') as f:
data_input.data = f.read()
stream_name = b'im_resnet18'
in_plugin_id = 0
unique_id = stream_manager_api.SendData(stream_name, in_plugin_id,
data_input)
if unique_id < 0:
print("Failed to send data to stream.")
return
# Obtain the inference result by specifying streamName and uniqueId.
start_time = datetime.datetime.now()
infer_result = stream_manager_api.GetResult(stream_name, unique_id)
end_time = datetime.datetime.now()
print('sdk run time: {}'.format((end_time - start_time).microseconds))
if infer_result.errorCode != 0:
print("GetResultWithUniqueId error. errorCode=%d, errorMsg=%s" % (
infer_result.errorCode, infer_result.data.decode()))
return
# print the infer result
infer_res = infer_result.data.decode()
print("process img: {}, infer result: {}".format(file_name, infer_res))
load_dict = json.loads(infer_result.data.decode())
if load_dict.get('MxpiClass') is None:
with open(res_dir_name + "/" + file_name[:-5] + '.txt',
'w') as f_write:
f_write.write("")
continue
res_vec = load_dict.get('MxpiClass')
with open(res_dir_name + "/" + file_name[:-5] + '_1.txt',
'w') as f_write:
res_list = [str(item.get("classId")) + " " for item in res_vec]
f_write.writelines(res_list)
f_write.write('\n')
# destroy streams
stream_manager_api.DestroyAllStreams()
if __name__ == '__main__':
run()

View File

@ -0,0 +1,177 @@
# coding = utf-8
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the BSD 3-Clause License (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import json
import numpy as np
np.set_printoptions(threshold=sys.maxsize)
LABEL_FILE = "HiAI_label.json"
def gen_file_name(img_name):
full_name = img_name.split('/')[-1]
return os.path.splitext(full_name)
def cre_groundtruth_dict(gtfile_path):
"""
:param filename: file contains the imagename and label number
:return: dictionary key imagename, value is label number
"""
img_gt_dict = {}
for gtfile in os.listdir(gtfile_path):
if gtfile != LABEL_FILE:
with open(os.path.join(gtfile_path, gtfile), 'r') as f:
gt = json.load(f)
ret = gt["image"]["annotations"][0]["category_id"]
img_gt_dict[gen_file_name(gtfile)] = ret
return img_gt_dict
def cre_groundtruth_dict_fromtxt(gtfile_path):
"""
:param filename: file contains the imagename and label number
:return: dictionary key imagename, value is label number
"""
img_gt_dict = {}
with open(gtfile_path, 'r')as f:
for line in f.readlines():
temp = line.strip().split(" ")
img_name = temp[0].split(".")[0]
img_lab = temp[1]
img_gt_dict[img_name] = img_lab
return img_gt_dict
def load_statistical_predict_result(filepath):
"""
function:
the prediction esult file data extraction
input:
result file:filepath
output:
n_label:numble of label
data_vec: the probabilitie of prediction in the 1000
:return: probabilities, numble of label, in_type, color
"""
with open(filepath, 'r')as f:
data = f.readline()
temp = data.strip().split(" ")
n_label = len(temp)
data_vec = np.zeros((n_label), dtype=np.float32)
in_type = ''
color = ''
if n_label == 0:
in_type = f.readline()
color = f.readline()
else:
for ind, cls_ind in enumerate(temp):
data_vec[ind] = np.int32(cls_ind)
return data_vec, n_label, in_type, color
def create_visualization_statistical_result(prediction_file_path,
result_store_path, file_name,
img_gt_dict, topn=5):
"""
:param prediction_file_path:
:param result_store_path:
:param file_name:
:param img_gt_dict:
:param topn:
:return:
"""
writer = open(os.path.join(result_store_path, file_name), 'w')
table_dict = {"title": "Overall statistical evaluation", "value": []}
count = 0
res_cnt = 0
n_labels = ""
count_hit = np.zeros(topn)
for tfile_name in os.listdir(prediction_file_path):
count += 1
temp = tfile_name.split('.')[0]
index = temp.rfind('_')
img_name = temp[:index]
filepath = os.path.join(prediction_file_path, tfile_name)
ret = load_statistical_predict_result(filepath)
prediction = ret[0]
n_labels = ret[1]
gt = img_gt_dict[img_name]
if n_labels == 1000:
real_label = int(gt)
elif n_labels == 1001:
real_label = int(gt) + 1
else:
real_label = int(gt)
res_cnt = min(len(prediction), topn)
for i in range(res_cnt):
if str(real_label) == str(int(prediction[i])):
count_hit[i] += 1
break
if 'value' not in table_dict.keys():
print("the item value does not exist!")
else:
table_dict["value"].extend(
[{"key": "Number of images", "value": str(count)},
{"key": "Number of classes", "value": str(n_labels)}])
if count == 0:
accuracy = 0
else:
accuracy = np.cumsum(count_hit) / count
for i in range(res_cnt):
table_dict["value"].append({"key": "Top" + str(i + 1) + " accuracy",
"value": str(
round(accuracy[i] * 100, 2)) + '%'})
json.dump(table_dict, writer)
writer.close()
if __name__ == '__main__':
try:
# txt file path
folder_davinci_target = sys.argv[1]
# annotation files path, "val_label.txt"
annotation_file_path = sys.argv[2]
# the path to store the results json path
result_json_path = sys.argv[3]
# result json file name
json_file_name = sys.argv[4]
except IndexError:
print("Please enter target file result folder | ground truth label file"
"| result json file folder | "
"result json file name, such as "
"./result val_label.txt . result.json")
exit(1)
if not os.path.exists(folder_davinci_target):
print("Target file folder does not exist.")
if not os.path.exists(annotation_file_path):
print("Ground truth file does not exist.")
if not os.path.exists(result_json_path):
print("Result folder doesn't exist.")
img_label_dict = cre_groundtruth_dict_fromtxt(annotation_file_path)
create_visualization_statistical_result(folder_davinci_target,
result_json_path, json_file_name,
img_label_dict, topn=5)

View File

@ -0,0 +1 @@
{"title": "Overall statistical evaluation", "value": [{"key": "Number of images", "value": "1000"}, {"key": "Number of classes", "value": "5"}, {"key": "Top1 accuracy", "value": "69.3%"}, {"key": "Top2 accuracy", "value": "80.1%"}, {"key": "Top3 accuracy", "value": "84.2%"}, {"key": "Top4 accuracy", "value": "86.2%"}, {"key": "Top5 accuracy", "value": "87.8%"}]}

View File

@ -0,0 +1,64 @@
{
"im_resnet18": {
"stream_config": {
"deviceId": "0"
},
"appsrc1": {
"props": {
"blocksize": "409600"
},
"factory": "appsrc",
"next": "mxpi_imagedecoder0"
},
"mxpi_imagedecoder0": {
"props": {
"handleMethod": "opencv"
},
"factory": "mxpi_imagedecoder",
"next": "mxpi_imageresize0"
},
"mxpi_imageresize0": {
"props": {
"handleMethod": "opencv",
"resizeType": "Resizer_Stretch",
"resizeHeight": "304",
"resizeWidth": "304"
},
"factory": "mxpi_imageresize",
"next": "mxpi_tensorinfer0"
},
"mxpi_tensorinfer0": {
"props": {
"dataSource": "mxpi_imageresize0",
"modelPath": "../data/model/resnet18-304_304.om",
"waitingTime": "2000",
"outputDeviceId": "-1"
},
"factory": "mxpi_tensorinfer",
"next": "mxpi_classpostprocessor0"
},
"mxpi_classpostprocessor0": {
"props": {
"dataSource": "mxpi_tensorinfer0",
"postProcessConfigPath": "../data/config/resnet18.cfg",
"labelPath": "../data/config/imagenet1000_clsidx_to_labels.names",
"postProcessLibPath": "../../../lib/modelpostprocessors/libresnet50postprocess.so"
},
"factory": "mxpi_classpostprocessor",
"next": "mxpi_dataserialize0"
},
"mxpi_dataserialize0": {
"props": {
"outputDataKeys": "mxpi_classpostprocessor0"
},
"factory": "mxpi_dataserialize",
"next": "appsink0"
},
"appsink0": {
"props": {
"blocksize": "4096000"
},
"factory": "appsink"
}
}
}

View File

@ -0,0 +1,32 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
image_path=$1
result_dir=$2
set -e
# Simple log helper functions
info() { echo -e "\033[1;34m[INFO ][MxStream] $1\033[1;37m" ; }
warn() { echo >&2 -e "\033[1;31m[WARN ][MxStream] $1\033[1;37m" ; }
export GST_PLUGIN_SCANNER=${MX_SDK_HOME}/opensource/libexec/gstreamer-1.0/gst-plugin-scanner
export GST_PLUGIN_PATH=${MX_SDK_HOME}/opensource/lib/gstreamer-1.0:${MX_SDK_HOME}/lib/plugins
#to set PYTHONPATH, import the StreamManagerApi.py
export PYTHONPATH=$PYTHONPATH:${MX_SDK_HOME}/python
python3.7 main.py $image_path $result_dir
exit 0

View File

@ -0,0 +1,440 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train resnet."""
import os
import argparse
import ast
import moxing as mox
import numpy as np
from mindspore import context
from mindspore import export
from mindspore import Tensor
from mindspore.nn.optim import Momentum, thor
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.train_thor import ConvertModelUtils
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.common import set_seed
from mindspore.parallel import set_algo_parameters
import mindspore.nn as nn
import mindspore.common.initializer as weight_init
import mindspore.log as logger
from src.lr_generator import get_lr, warmup_cosine_annealing_lr, get_resnet34_lr
from src.CrossEntropySmooth import CrossEntropySmooth
from src.config import cfg
from src.eval_callback import EvalCallBack
from src.metric import DistAccuracy, ClassifyCorrectCell
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--train_url', type=str, default='',
help='the path model saved')
parser.add_argument('--data_url', type=str, default='',
help='the training data')
parser.add_argument('--net', type=str, default="resnet18",
help='Resnet Model, resnet18, resnet34, '
'resnet50 or resnet101')
parser.add_argument('--dataset', type=str, default="imagenet2012",
help='Dataset, either cifar10 or imagenet2012')
parser.add_argument('--run_distribute', type=ast.literal_eval, default=False,
help='Run distribute')
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--dataset_path', type=str, default="/cache",
help='Dataset path')
parser.add_argument('--device_target', type=str, default='Ascend',
choices=("Ascend", "GPU", "CPU"),
help="Device target, support Ascend, GPU and CPU.")
parser.add_argument('--pre_trained', type=str, default=None,
help='Pretrained checkpoint path')
parser.add_argument('--parameter_server', type=ast.literal_eval, default=False,
help='Run parameter server train')
parser.add_argument("--filter_weight", type=ast.literal_eval, default=False,
help="Filter head weight parameters, default is False.")
parser.add_argument("--run_eval", type=ast.literal_eval, default=False,
help="Run evaluation when training, default is False.")
parser.add_argument('--eval_dataset_path', type=str, default=None,
help='Evaluation dataset path when run_eval is True')
parser.add_argument("--save_best_ckpt", type=ast.literal_eval, default=True,
help="Save best checkpoint when run_eval is True, "
"default is True.")
parser.add_argument("--eval_start_epoch", type=int, default=40,
help="Evaluation start epoch when run_eval is True, "
"default is 40.")
parser.add_argument("--eval_interval", type=int, default=1,
help="Evaluation interval when run_eval is True, "
"default is 1.")
parser.add_argument('--enable_cache', type=ast.literal_eval, default=False,
help='Caching the eval dataset in memory to speedup '
'evaluation, default is False.')
parser.add_argument('--cache_session_id', type=str, default="",
help='The session id for cache service.')
parser.add_argument('--mode', type=str, default='GRAPH',
choices=('GRAPH', 'PYNATIVE'),
help="Graph mode or PyNative mode, default is Graph mode")
parser.add_argument("--epoch_size", type=int, default=1,
help="training epoch size, default is 1.")
parser.add_argument("--num_classes", type=int, default=1001,
help="number of dataset categories, default is 1001.")
args_opt = parser.parse_args()
CKPT_OUTPUT_PATH = "./"
set_seed(1)
if args_opt.net in ("resnet18", "resnet50"):
if args_opt.net == "resnet18":
from src.resnet import resnet18 as resnet
if args_opt.net == "resnet50":
from src.resnet import resnet50 as resnet
if args_opt.dataset == "cifar10":
from src.config import config1 as config
from src.dataset import create_dataset1 as create_dataset
else:
from src.config import config2 as config
if args_opt.mode == "GRAPH":
from src.dataset import create_dataset2 as create_dataset
else:
from src.dataset import create_dataset_pynative as create_dataset
elif args_opt.net == "resnet34":
from src.resnet import resnet34 as resnet
from src.config import config_resnet34 as config
from src.dataset import create_dataset_resnet34 as create_dataset
elif args_opt.net == "resnet101":
from src.resnet import resnet101 as resnet
from src.config import config3 as config
from src.dataset import create_dataset3 as create_dataset
else:
from src.resnet import se_resnet50 as resnet
from src.config import config4 as config
from src.dataset import create_dataset4 as create_dataset
if cfg.optimizer == "Thor":
if args_opt.device_target == "Ascend":
from src.config import config_thor_Ascend as config
else:
from src.config import config_thor_gpu as config
def filter_checkpoint_parameter_by_list(origin_dict, param_filter):
"""remove useless parameters according to filter_list"""
for key in list(origin_dict.keys()):
for name in param_filter:
if name in key:
print("Delete parameter from checkpoint: ", key)
del origin_dict[key]
break
def apply_eval(eval_param):
eval_model = eval_param["model"]
eval_ds = eval_param["dataset"]
metrics_name = eval_param["metrics_name"]
res = eval_model.eval(eval_ds)
return res[metrics_name]
def set_graph_kernel_context(run_platform, net_name):
if run_platform == "GPU" and net_name == "resnet101":
context.set_context(enable_graph_kernel=True,
graph_kernel_flags="--enable_parallel_fusion")
def _get_last_ckpt(ckpt_dir):
ckpt_files = [ckpt_file for ckpt_file in os.listdir(ckpt_dir)
if ckpt_file.endswith('.ckpt')]
if not ckpt_files:
print("No ckpt file found.")
return None
return os.path.join(ckpt_dir, sorted(ckpt_files)[-1])
def _export_air(ckpt_dir):
ckpt_file = _get_last_ckpt(ckpt_dir)
if not ckpt_file:
return
net = resnet(config.class_num)
param_dict = load_checkpoint(ckpt_file)
load_param_into_net(net, param_dict)
input_arr = Tensor(np.zeros([1, 3, 304, 304],
np.float32))
export(net, input_arr, file_name="resnet",
file_format="AIR")
def set_config():
config.epoch_size = args_opt.epoch_size
config.num_classes = args_opt.num_classes
def init_context(target):
if args_opt.mode == 'GRAPH':
context.set_context(mode=context.GRAPH_MODE, device_target=target,
save_graphs=False)
set_graph_kernel_context(target, args_opt.net)
else:
context.set_context(mode=context.PYNATIVE_MODE, device_target=target,
save_graphs=False)
if args_opt.parameter_server:
context.set_ps_context(enable_ps=True)
if args_opt.run_distribute:
if target == "Ascend":
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id,
enable_auto_mixed_precision=True)
context.set_auto_parallel_context(
device_num=args_opt.device_num,
parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
set_algo_parameters(elementwise_op_strategy_follow=True)
if args_opt.net == "resnet50" or args_opt.net == "se-resnet50":
context.set_auto_parallel_context(
all_reduce_fusion_config=[85, 160])
elif args_opt.net == "resnet101":
context.set_auto_parallel_context(
all_reduce_fusion_config=[80, 210, 313])
init()
# GPU target
else:
init()
context.set_auto_parallel_context(
device_num=get_group_size(),
parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
if args_opt.net == "resnet50":
context.set_auto_parallel_context(
all_reduce_fusion_config=[85, 160])
def init_weight(net):
if os.path.exists(args_opt.pre_trained):
param_dict = load_checkpoint(args_opt.pre_trained)
if args_opt.filter_weight:
filter_list = [x.name for x in net.end_point.get_parameters()]
filter_checkpoint_parameter_by_list(param_dict, filter_list)
load_param_into_net(net, param_dict)
else:
for _, cell in net.cells_and_names():
if isinstance(cell, nn.Conv2d):
cell.weight.set_data(
weight_init.initializer(weight_init.XavierUniform(),
cell.weight.shape,
cell.weight.dtype))
if isinstance(cell, nn.Dense):
cell.weight.set_data(
weight_init.initializer(weight_init.TruncatedNormal(),
cell.weight.shape,
cell.weight.dtype))
def init_lr(step_size):
if cfg.optimizer == "Thor":
from src.lr_generator import get_thor_lr
lr = get_thor_lr(0, config.lr_init, config.lr_decay,
config.lr_end_epoch, step_size, decay_epochs=39)
else:
if args_opt.net in ("resnet18", "resnet34", "resnet50", "se-resnet50"):
lr = get_lr(lr_init=config.lr_init, lr_end=config.lr_end,
lr_max=config.lr_max,
warmup_epochs=config.warmup_epochs,
total_epochs=config.epoch_size,
steps_per_epoch=step_size,
lr_decay_mode=config.lr_decay_mode)
else:
lr = warmup_cosine_annealing_lr(
config.lr, step_size, config.warmup_epochs, config.epoch_size,
config.pretrain_epoch_size * step_size)
if args_opt.net == "resnet34":
lr = get_resnet34_lr(lr_init=config.lr_init,
lr_end=config.lr_end,
lr_max=config.lr_max,
warmup_epochs=config.warmup_epochs,
total_epochs=config.epoch_size,
steps_per_epoch=step_size)
return Tensor(lr)
def define_opt(net, lr):
decayed_params = []
no_decayed_params = []
for param in net.trainable_params():
if 'beta' not in param.name and 'gamma' not in param.name and 'bias' \
not in param.name:
decayed_params.append(param)
else:
no_decayed_params.append(param)
group_params = [
{'params': decayed_params, 'weight_decay': config.weight_decay},
{'params': no_decayed_params},
{'order_params': net.trainable_params()}]
opt = Momentum(group_params, lr, config.momentum,
loss_scale=config.loss_scale)
return opt
def define_model(net, opt, target):
if args_opt.dataset == "imagenet2012":
if not config.use_label_smooth:
config.label_smooth_factor = 0.0
loss = CrossEntropySmooth(sparse=True, reduction="mean",
smooth_factor=config.label_smooth_factor,
num_classes=config.class_num)
else:
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
loss_scale = FixedLossScaleManager(config.loss_scale,
drop_overflow_update=False)
dist_eval_network = ClassifyCorrectCell(
net) if args_opt.run_distribute else None
metrics = {"acc"}
if args_opt.run_distribute:
metrics = {'acc': DistAccuracy(batch_size=config.batch_size,
device_num=args_opt.device_num)}
if (args_opt.net not in ("resnet18", "resnet50", "resnet101",
"se-resnet50")) or args_opt.parameter_server \
or target == "CPU":
# fp32 training
model = Model(net, loss_fn=loss, optimizer=opt, metrics=metrics,
eval_network=dist_eval_network)
else:
model = Model(net, loss_fn=loss, optimizer=opt,
loss_scale_manager=loss_scale, metrics=metrics,
amp_level="O2", keep_batchnorm_fp32=False,
eval_network=dist_eval_network)
return model, loss, loss_scale
def run_eval(model, target, ckpt_save_dir):
if args_opt.eval_dataset_path is None \
or (not os.path.isdir(args_opt.eval_dataset_path)):
raise ValueError(
"{} is not a existing path.".format(args_opt.eval_dataset_path))
eval_dataset = create_dataset(
dataset_path=args_opt.eval_dataset_path,
do_train=False,
batch_size=config.batch_size,
target=target,
enable_cache=args_opt.enable_cache,
cache_session_id=args_opt.cache_session_id)
eval_param_dict = {"model": model, "dataset": eval_dataset,
"metrics_name": "acc"}
eval_cb = EvalCallBack(apply_eval, eval_param_dict,
interval=args_opt.eval_interval,
eval_start_epoch=args_opt.eval_start_epoch,
save_best_ckpt=args_opt.save_best_ckpt,
ckpt_directory=ckpt_save_dir,
besk_ckpt_name="best_acc.ckpt",
metrics_name="acc")
return eval_cb
def main():
set_config()
target = args_opt.device_target
if target == "CPU":
args_opt.run_distribute = False
# init context
init_context(target)
ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(
get_rank()) + "/"
# create dataset
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True,
repeat_num=1,
batch_size=config.batch_size, target=target,
distribute=args_opt.run_distribute)
step_size = dataset.get_dataset_size()
# define net
net = resnet(class_num=config.class_num)
if args_opt.parameter_server:
net.set_param_ps()
# init weight
init_weight(net)
# init lr
lr = init_lr(step_size)
# define opt
opt = define_opt(net, lr)
# define model
model, loss, loss_scale = define_model(net, opt, target)
if cfg.optimizer == "Thor" and args_opt.dataset == "imagenet2012":
from src.lr_generator import get_thor_damping
damping = get_thor_damping(0, config.damping_init, config.damping_decay,
70, step_size)
split_indices = [26, 53]
opt = thor(net, lr, Tensor(damping), config.momentum,
config.weight_decay, config.loss_scale,
config.batch_size, split_indices=split_indices,
frequency=config.frequency)
model = ConvertModelUtils().convert_to_thor_model(
model=model, network=net, loss_fn=loss, optimizer=opt,
loss_scale_manager=loss_scale, metrics={'acc'}, amp_level="O2",
keep_batchnorm_fp32=False)
args_opt.run_eval = False
logger.warning("Thor optimizer not support evaluation while training.")
# define callbacks
time_cb = TimeMonitor(data_size=step_size)
loss_cb = LossMonitor()
cb = [time_cb, loss_cb]
if config.save_checkpoint:
config_ck = CheckpointConfig(
save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir,
config=config_ck)
cb += [ckpt_cb]
if args_opt.run_eval:
eval_cb = run_eval(model, target, ckpt_save_dir)
cb += [eval_cb]
# train model
if args_opt.net == "se-resnet50":
config.epoch_size = config.train_epoch_size
dataset_sink_mode = (not args_opt.parameter_server) and target != "CPU"
model.train(config.epoch_size - config.pretrain_epoch_size, dataset,
callbacks=cb,
sink_size=dataset.get_dataset_size(),
dataset_sink_mode=dataset_sink_mode)
if args_opt.run_eval and args_opt.enable_cache:
print(
"Remember to shut down the cache server via \"cache_admin --stop\"")
if __name__ == '__main__':
# 将数据集拷贝到ModelArts指定读取的cache目录
mox.file.copy_parallel(args_opt.data_url, '/cache')
main()
# 训练完成后把生成的模型拷贝到指导输出目录
if not os.path.exists(CKPT_OUTPUT_PATH):
os.makedirs(CKPT_OUTPUT_PATH, exist_ok=True)
_export_air(CKPT_OUTPUT_PATH)
mox.file.copy_parallel(CKPT_OUTPUT_PATH, args_opt.train_url)