modify again based on retest info

This commit is contained in:
Dai Surong 2021-08-30 14:59:16 +08:00
parent c3211bb9ef
commit bf41dba0d9
4 changed files with 11 additions and 19 deletions

View File

@ -20,7 +20,6 @@
#include <iostream> #include <iostream>
#include <opencv2/dnn.hpp> #include <opencv2/dnn.hpp>
using namespace MxBase; using namespace MxBase;
using namespace cv::dnn; using namespace cv::dnn;
namespace { namespace {
@ -105,21 +104,18 @@ APP_ERROR DPN::ResizeImage(const cv::Mat &srcImageMat, cv::Mat &dstImageMat)
APP_ERROR DPN::CVMatToTensorBase(const cv::Mat &imageMat, MxBase::TensorBase &tensorBase) APP_ERROR DPN::CVMatToTensorBase(const cv::Mat &imageMat, MxBase::TensorBase &tensorBase)
{ {
uint32_t dataSize=1; uint32_t dataSize=1;
for (size_t i=0; i<modelDesc_.inputTensors.size(); ++i) {
for (size_t i = 0; i < modelDesc_.inputTensors.size(); ++i) {
std::vector<uint32_t> shape = {}; std::vector<uint32_t> shape = {};
for (size_t j = 0; j < modelDesc_.inputTensors[i].tensorDims.size(); ++j) { for (size_t j = 0; j < modelDesc_.inputTensors[i].tensorDims.size(); ++j) {
shape.push_back((uint32_t)modelDesc_.inputTensors[i].tensorDims[j]); shape.push_back((uint32_t)modelDesc_.inputTensors[i].tensorDims[j]);
} }
for(uint32_t i = 0; i < shape.size(); ++i){ for(uint32_t i=0; i<shape.size(); ++i){
dataSize *= shape[i]; dataSize *= shape[i];
} }
std::cout<< std::endl;
} }
// mat NCHW to NHWC // mat NCHW to NHWC
size_t N=32, H=224, W=224, C=3; size_t N=32,H=224,W=224,C=3;
unsigned char *mat_data = new unsigned char[dataSize]; unsigned char *mat_data = new unsigned char[dataSize];
uint32_t idx=0; uint32_t idx=0;
for(size_t n=0; n<N; n++){ for(size_t n=0; n<N; n++){
@ -190,7 +186,8 @@ APP_ERROR DPN::PostProcess(const std::vector<MxBase::TensorBase> &inputs,
return APP_ERR_OK; return APP_ERR_OK;
} }
APP_ERROR DPN::SaveResult(const std::vector<std::string> &batchImgPaths, const std::vector<std::vector<MxBase::ClassInfo>> &batchClsInfos) APP_ERROR DPN::SaveResult(const std::vector<std::string> &batchImgPaths,
const std::vector<std::vector<MxBase::ClassInfo>> &batchClsInfos)
{ {
uint32_t batchIndex = 0; uint32_t batchIndex = 0;
for(auto &imgPath: batchImgPaths){ for(auto &imgPath: batchImgPaths){

View File

@ -61,5 +61,4 @@ class DPN {
double inferCostTimeMilliSec = 0.0; double inferCostTimeMilliSec = 0.0;
}; };
#endif #endif

View File

@ -43,7 +43,6 @@ APP_ERROR ScanImages(const std::string &path, std::vector<std::string> &imgFiles
return APP_ERR_OK; return APP_ERR_OK;
} }
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
if (argc <= 1) { if (argc <= 1) {
@ -75,7 +74,7 @@ int main(int argc, char* argv[])
auto startTime = std::chrono::high_resolution_clock::now(); auto startTime = std::chrono::high_resolution_clock::now();
int inferImgsCount = 0; int inferImgsCount = 0;
LogInfo << "Number of total images load from input data path: " << imgFilePaths.size(); LogInfo << "Number of total images load from input data path: " << imgFilePaths.size();
for(uint32_t i=0; i<=imgFilePaths.size()-BATCH_SIZE; i+=BATCH_SIZE){ for(uint32_t i = 0; i <= imgFilePaths.size() - BATCH_SIZE; i += BATCH_SIZE){
std::vector<std::string>batchImgFilePaths(imgFilePaths.begin()+i, imgFilePaths.begin()+(i+BATCH_SIZE)); std::vector<std::string>batchImgFilePaths(imgFilePaths.begin()+i, imgFilePaths.begin()+(i+BATCH_SIZE));
ret = dpn->Process(batchImgFilePaths); ret = dpn->Process(batchImgFilePaths);
if (ret != APP_ERR_OK) { if (ret != APP_ERR_OK) {

View File

@ -160,7 +160,6 @@ def filter_weight_by_list(origin_dict, param_filter):
break break
def dpn_train(config_args, ma_config): def dpn_train(config_args, ma_config):
# create dataset
ma_config["training_data"] = config_args.data_path + "/train" ma_config["training_data"] = config_args.data_path + "/train"
ma_config["image_size"] = [config_args.image_size_height, config_args.image_size_width] ma_config["image_size"] = [config_args.image_size_height, config_args.image_size_width]
train_dataset = classification_dataset(ma_config["training_data"], train_dataset = classification_dataset(ma_config["training_data"],
@ -269,13 +268,12 @@ def dpn_train(config_args, ma_config):
return 0 return 0
def dpn_export(config_args, ma_config): def dpn_export(config_args, ma_config):
# define net
backbone = config_args.backbone backbone = config_args.backbone
num_classes = config_args.num_classes num_classes = config_args.num_classes
net = dpns[backbone](num_classes=num_classes) net = dpns[backbone](num_classes=num_classes)
# load checkpoint # load checkpoint
prob_ckpt_list = os.path.join(ma_config["checkpoint_path"] , "dpn*.ckpt") prob_ckpt_list = os.path.join(ma_config["checkpoint_path"], "dpn*.ckpt")
ckpt_list = glob.glob(prob_ckpt_list) ckpt_list = glob.glob(prob_ckpt_list)
if not ckpt_list: if not ckpt_list:
print('Freezing model failed!') print('Freezing model failed!')
@ -299,18 +297,17 @@ def dpn_export(config_args, ma_config):
def main(): def main():
# parser arguments
config_args = _parse_args() config_args = _parse_args()
# create local path # create local path
if not os.path.exists(config_args.data_path): if not os.path.exists(config_args.data_path):
os.makedirs(config_args.data_path, exist_ok=True) os.makedirs(config_args.data_path, exist_ok=True)
if not os.path.exists(config_args.output_path): if not os.path.exists(config_args.output_path):
os.makedirs(config_args.output_path, exist_ok=True) os.makedirs(config_args.output_path, exist_ok=True)
ma_config = {} ma_config = {}
# init context # init context
ma_config["checkpoint_path"] = os.path.join(config_args.output_path, config_args.checkpoint_dir) ma_config["checkpoint_path"] = os.path.join(config_args.output_path, config_args.checkpoint_dir)
if not os.path.exists(ma_config["checkpoint_path"]): if not os.path.exists(ma_config["checkpoint_path"]):
os.makedirs(ma_config["checkpoint_path"], exist_ok=True) os.makedirs(ma_config["checkpoint_path"], exist_ok=True)
ma_config["device_id"] = get_device_id() ma_config["device_id"] = get_device_id()
context.set_context(mode=context.GRAPH_MODE, context.set_context(mode=context.GRAPH_MODE,
device_target=config_args.device_target, save_graphs=False, device_id=ma_config["device_id"]) device_target=config_args.device_target, save_graphs=False, device_id=ma_config["device_id"])