diff --git a/model_zoo/official/cv/crnn/ascend310_infer/src/main.cc b/model_zoo/official/cv/crnn/ascend310_infer/src/main.cc index a815c4b95c5..daf1c1813e1 100644 --- a/model_zoo/official/cv/crnn/ascend310_infer/src/main.cc +++ b/model_zoo/official/cv/crnn/ascend310_infer/src/main.cc @@ -33,12 +33,12 @@ #include "inc/utils.h" -using mindspore::GlobalContext; +using mindspore::Context; using mindspore::Serialization; using mindspore::Model; -using mindspore::ModelContext; using mindspore::Status; using mindspore::ModelType; +using mindspore::Graph; using mindspore::GraphCell; using mindspore::kSuccess; using mindspore::MSTensor; @@ -64,21 +64,28 @@ int main(int argc, char **argv) { return 1; } - GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310); - GlobalContext::SetGlobalDeviceID(FLAGS_device_id); - auto graph = Serialization::LoadModel(FLAGS_mindir_path, ModelType::kMindIR); - auto model_context = std::make_shared(); - if (!FLAGS_aipp_path.empty()) { - ModelContext::SetInsertOpConfigPath(model_context, FLAGS_aipp_path); + auto context = std::make_shared(); + auto ascend310_info = std::make_shared(); + ascend310_info->SetDeviceID(FLAGS_device_id); + ascend310_info->SetInsertOpConfigPath({FLAGS_aipp_path}); + context->MutableDeviceInfo().push_back(ascend310_info); + + Graph graph; + Status ret = Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph); + if (ret != kSuccess) { + std::cout << "Load model failed." << std::endl; + return 1; } - Model model(GraphCell(graph), model_context); - Status ret = model.Build(); + Model model; + ret = model.Build(GraphCell(graph), context); if (ret != kSuccess) { std::cout << "ERROR: Build failed." << std::endl; return 1; } + std::vector modelInputs = model.GetInputs(); + auto allFiles = GetAllFiles(FLAGS_dataset_path); if (allFiles.empty()) { std::cout << "ERROR: no input data." << std::endl; @@ -108,11 +115,12 @@ int main(int argc, char **argv) { std::cout << "wrong file format: " << allFiles[i] << std::endl; continue; } - auto img = std::make_shared(); - compose(ReadFileToTensor(allFiles[i]), img.get()); - inputs.emplace_back(img->Name(), img->DataType(), img->Shape(), - img->Data().get(), img->DataSize()); + mindspore::MSTensor img; + compose(ReadFileToTensor(allFiles[i]), &img); + + inputs.emplace_back(modelInputs[0].Name(), modelInputs[0].DataType(), modelInputs[0].Shape(), + img.Data().get(), img.DataSize()); gettimeofday(&start, NULL); ret = model.Predict(inputs, &outputs); diff --git a/model_zoo/official/cv/ctpn/ascend310_infer/src/main.cc b/model_zoo/official/cv/ctpn/ascend310_infer/src/main.cc index 4c20c852731..5c07a3c8537 100644 --- a/model_zoo/official/cv/ctpn/ascend310_infer/src/main.cc +++ b/model_zoo/official/cv/ctpn/ascend310_infer/src/main.cc @@ -34,12 +34,12 @@ #include "include/api/serialization.h" #include "include/api/context.h" -using mindspore::GlobalContext; using mindspore::Serialization; using mindspore::Model; -using mindspore::ModelContext; +using mindspore::Context; using mindspore::Status; using mindspore::ModelType; +using mindspore::Graph; using mindspore::GraphCell; using mindspore::kSuccess; using mindspore::MSTensor; @@ -71,18 +71,27 @@ int main(int argc, char **argv) { return 1; } - GlobalContext::SetGlobalDeviceTarget(FLAGS_device_target); - GlobalContext::SetGlobalDeviceID(FLAGS_device_id); + auto context = std::make_shared(); + auto ascend310_info = std::make_shared(); + ascend310_info->SetDeviceID(FLAGS_device_id); + context->MutableDeviceInfo().push_back(ascend310_info); - auto graph = Serialization::LoadModel(FLAGS_model_path, ModelType::kMindIR); - - Model model((GraphCell(graph))); - Status ret = model.Build(); + Graph graph; + Status ret = Serialization::Load(FLAGS_model_path, ModelType::kMindIR, &graph); if (ret != kSuccess) { - std::cout << "ERROR Build failed." << std::endl; + std::cout << "Load model failed." << std::endl; return 1; } + Model model; + ret = model.Build(GraphCell(graph), context); + if (ret != kSuccess) { + std::cout << "ERROR: Build failed." << std::endl; + return 1; + } + + std::vector modelInputs = model.GetInputs(); + auto all_files = GetAllFiles(FLAGS_dataset_path); if (all_files.empty()) { std::cout << "ERROR: no input data." << std::endl; @@ -118,7 +127,8 @@ int main(int argc, char **argv) { transform(image, &image); transformCast(image, &image); - inputs.emplace_back(image); + inputs.emplace_back(modelInputs[0].Name(), modelInputs[0].DataType(), modelInputs[0].Shape(), + image.Data().get(), image.DataSize()); gettimeofday(&start, NULL); model.Predict(inputs, &outputs); diff --git a/model_zoo/official/cv/faster_rcnn/ascend310_infer/src/AclProcess.cpp b/model_zoo/official/cv/faster_rcnn/ascend310_infer/src/AclProcess.cpp index ff0d7f75a17..a70448e16e5 100755 --- a/model_zoo/official/cv/faster_rcnn/ascend310_infer/src/AclProcess.cpp +++ b/model_zoo/official/cv/faster_rcnn/ascend310_infer/src/AclProcess.cpp @@ -165,6 +165,10 @@ int AclProcess::WriteResult(const std::string& imageFile) { std::string outFileName = homePath + "/" + fileName; try { FILE * outputFile = fopen(outFileName.c_str(), "wb"); + if (outputFile == nullptr) { + std::cout << "open result file " << outFileName << " failed" << std::endl; + return INVALID_POINTER; + } fwrite(resHostBuf, output_size, sizeof(char), outputFile); fclose(outputFile); outputFile = nullptr; diff --git a/model_zoo/official/cv/faster_rcnn/ascend310_infer/src/main.cpp b/model_zoo/official/cv/faster_rcnn/ascend310_infer/src/main.cpp index c2c998a011b..63b46694e9d 100755 --- a/model_zoo/official/cv/faster_rcnn/ascend310_infer/src/main.cpp +++ b/model_zoo/official/cv/faster_rcnn/ascend310_infer/src/main.cpp @@ -79,7 +79,11 @@ int main(int argc, char* argv[]) { return ret; } if (is_file(FLAGS_data_path)) { - aclProcess.Process(FLAGS_data_path, &costTime_map); + ret = aclProcess.Process(FLAGS_data_path, &costTime_map); + if (ret != OK) { + std::cout << "model process failed, errno = " << ret << std::endl; + return ret; + } } else if (is_dir(FLAGS_data_path)) { struct dirent * filename; DIR * dir; @@ -93,7 +97,11 @@ int main(int argc, char* argv[]) { continue; } std::string wholePath = FLAGS_data_path + "/" + filename->d_name; - aclProcess.Process(wholePath, &costTime_map); + ret = aclProcess.Process(wholePath, &costTime_map); + if (ret != OK) { + std::cout << "model process failed, errno = " << ret << std::endl; + return ret; + } } } else { std::cout << " input image path error" << std::endl; diff --git a/model_zoo/official/cv/maskrcnn/ascend310_infer/src/AclProcess.cpp b/model_zoo/official/cv/maskrcnn/ascend310_infer/src/AclProcess.cpp index 20b0e45b3cf..addafb05be0 100755 --- a/model_zoo/official/cv/maskrcnn/ascend310_infer/src/AclProcess.cpp +++ b/model_zoo/official/cv/maskrcnn/ascend310_infer/src/AclProcess.cpp @@ -165,6 +165,10 @@ int AclProcess::WriteResult(const std::string& imageFile) { std::string outFileName = homePath + "/" + fileName; try { FILE * outputFile = fopen(outFileName.c_str(), "wb"); + if (outputFile == nullptr) { + std::cout << "open result file " << outFileName << " failed" << std::endl; + return INVALID_POINTER; + } fwrite(resHostBuf, output_size, sizeof(char), outputFile); fclose(outputFile); outputFile = nullptr; diff --git a/model_zoo/official/cv/maskrcnn/ascend310_infer/src/main.cpp b/model_zoo/official/cv/maskrcnn/ascend310_infer/src/main.cpp index 9c0d831d5ca..fbed56bcc5f 100755 --- a/model_zoo/official/cv/maskrcnn/ascend310_infer/src/main.cpp +++ b/model_zoo/official/cv/maskrcnn/ascend310_infer/src/main.cpp @@ -79,7 +79,11 @@ int main(int argc, char* argv[]) { return ret; } if (is_file(FLAGS_data_path)) { - aclProcess.Process(FLAGS_data_path, &costTime_map); + ret = aclProcess.Process(FLAGS_data_path, &costTime_map); + if (ret != OK) { + std::cout << "model process failed, errno = " << ret << std::endl; + return ret; + } } else if (is_dir(FLAGS_data_path)) { struct dirent * filename; DIR * dir; @@ -93,7 +97,11 @@ int main(int argc, char* argv[]) { continue; } std::string wholePath = FLAGS_data_path + "/" + filename->d_name; - aclProcess.Process(wholePath, &costTime_map); + ret = aclProcess.Process(wholePath, &costTime_map); + if (ret != OK) { + std::cout << "model process failed, errno = " << ret << std::endl; + return ret; + } } } else { std::cout << " input image path error" << std::endl;