diff --git a/mindspore/lite/examples/runtime_extend/main.cc b/mindspore/lite/examples/runtime_extend/main.cc
index b2efca43fb1..9f558a16abb 100644
--- a/mindspore/lite/examples/runtime_extend/main.cc
+++ b/mindspore/lite/examples/runtime_extend/main.cc
@@ -14,206 +14,97 @@
  * limitations under the License.
  */
-#include <algorithm>
-#include <random>
 #include <iostream>
-#include <fstream>
 #include <cstring>
-#include "include/errorcode.h"
-#include "include/model.h"
-#include "include/context.h"
-#include "include/lite_session.h"
+#include <memory>
+#include <vector>
+#include <string>
+#include "include/api/status.h"
+#include "include/api/context.h"
+#include "include/api/model.h"
+
 namespace mindspore {
 namespace lite {
 namespace {
 constexpr int kNumPrintOfOutData = 20;
-std::string RealPath(const char *path) {
-  const size_t max = 4096;
-  if (path == nullptr) {
-    std::cerr << "path is nullptr" << std::endl;
-    return "";
+Status FillInputData(const std::vector<mindspore::MSTensor> &inputs) {
+  for (auto tensor : inputs) {
+    auto input_data = tensor.MutableData();
+    if (input_data == nullptr) {
+      std::cerr << "MallocData for inTensor failed.\n";
+      return kLiteError;
+    }
+    std::vector<float> temp(tensor.ElementNum(), 1.0f);
+    memcpy(input_data, temp.data(), tensor.DataSize());
   }
-  if ((strlen(path)) >= max) {
-    std::cerr << "path is too long" << std::endl;
-    return "";
-  }
-  auto resolved_path = std::make_unique<char[]>(max);
-  if (resolved_path == nullptr) {
-    std::cerr << "new resolved_path failed" << std::endl;
-    return "";
-  }
-#ifdef _WIN32
-  char *real_path = _fullpath(resolved_path.get(), path, 1024);
-#else
-  char *real_path = realpath(path, resolved_path.get());
-#endif
-  if (real_path == nullptr || strlen(real_path) == 0) {
-    std::cerr << "file path is not valid : " << path << std::endl;
-    return "";
-  }
-  std::string res = resolved_path.get();
-  return res;
-}
-
-char *ReadFile(const char *file, size_t *size) {
-  if (file == nullptr) {
-    std::cerr << "file is nullptr." << std::endl;
-    return nullptr;
-  }
-
-  std::ifstream ifs(file);
-  if (!ifs.good()) {
-    std::cerr << "file: " << file << " is not exist." << std::endl;
-    return nullptr;
-  }
-
-  if (!ifs.is_open()) {
-    std::cerr << "file: " << file << " open failed." << std::endl;
-    return nullptr;
-  }
-
-  ifs.seekg(0, std::ios::end);
-  *size = ifs.tellg();
-  std::unique_ptr<char[]> buf(new (std::nothrow) char[*size]);
-  if (buf == nullptr) {
-    std::cerr << "malloc buf failed, file: " << file << std::endl;
-    ifs.close();
-    return nullptr;
-  }
-
-  ifs.seekg(0, std::ios::beg);
-  ifs.read(buf.get(), *size);
-  ifs.close();
-
-  return buf.release();
+  return kSuccess;
 }
 }  // namespace
-template <typename T, typename Distribution>
-void GenerateRandomData(int size, void *data, Distribution distribution) {
-  std::mt19937 random_engine;
-  int elements_num = size / sizeof(T);
-  (void)std::generate_n(static_cast<T *>(data), elements_num,
-                        [&distribution, &random_engine]() { return static_cast<T>(distribution(random_engine)); });
-}
-
-int GenerateInputDataWithRandom(std::vector<mindspore::tensor::MSTensor *> inputs) {
-  for (auto tensor : inputs) {
-    auto input_data = tensor->MutableData();
-    if (input_data == nullptr) {
-      std::cerr << "MallocData for inTensor failed." << std::endl;
-      return RET_ERROR;
-    }
-    GenerateRandomData<float>(tensor->Size(), input_data, std::uniform_real_distribution<float>(1.0f, 1.0f));
+Status CompileAndRun(int argc, const char **argv) {
+  if (argc < 2) {
+    std::cerr << "Model file must be provided.\n";
+    return kLiteError;
   }
-  return RET_OK;
-}
+  // generate context.
+  auto context = std::make_shared<mindspore::Context>();
+  if (context == nullptr) {
+    std::cerr << "New context failed while running.\n";
+    return kLiteError;
+  }
+  auto &device_list = context->MutableDeviceInfo();
+  std::shared_ptr<mindspore::CPUDeviceInfo> device_info = std::make_shared<mindspore::CPUDeviceInfo>();
+  device_info->SetProvider("Tutorial");
+  device_info->SetProviderDevice("Tutorial");
+  device_list.push_back(device_info);
-int Run(mindspore::session::LiteSession *session) {
-  auto inputs = session->GetInputs();
+  // build model.
+  std::string model_file = std::string(argv[1]);
+  mindspore::Model model;
+  auto ret = model.Build(model_file, kMindIR, context);
+  if (ret != kSuccess) {
+    std::cerr << "build model failed.\n";
+    return kLiteError;
+  }
-  // Generate random data as input data.
-  auto ret = GenerateInputDataWithRandom(inputs);
-  if (ret != RET_OK) {
-    std::cerr << "Generate Random Input Data failed." << std::endl;
+  // fill input data.
+  auto inputs = model.GetInputs();
+  ret = FillInputData(inputs);
+  if (ret != kSuccess) {
+    std::cerr << "Generate Random Input Data failed.\n";
     return ret;
   }
-  // Run Inference.
-  ret = session->RunGraph();
-  if (ret != RET_OK) {
-    std::cerr << "Inference error " << ret << std::endl;
+  // run model.
+  std::vector<mindspore::MSTensor> outputs;
+  ret = model.Predict(inputs, &outputs);
+  if (ret != kSuccess) {
+    std::cerr << "run model failed.\n";
     return ret;
   }
-  // Get Output Tensor Data.
-  auto out_tensors = session->GetOutputs();
-  for (auto tensor : out_tensors) {
-    std::cout << "tensor name is:" << tensor.first << " tensor size is:" << tensor.second->Size()
-              << " tensor elements num is:" << tensor.second->ElementsNum() << std::endl;
-    auto out_data = reinterpret_cast<float *>(tensor.second->MutableData());
+  // display output result.
+  for (auto tensor : outputs) {
+    std::cout << "tensor name is:" << tensor.Name() << " tensor size is:" << tensor.DataSize()
+              << " tensor elements num is:" << tensor.ElementNum() << std::endl;
+    auto out_data = std::static_pointer_cast<const float>(tensor.Data());
     std::cout << "output data is:";
-    for (int i = 0; i < tensor.second->ElementsNum() && i <= kNumPrintOfOutData; i++) {
-      std::cout << out_data[i] << " ";
+    for (int i = 0; i < tensor.ElementNum() && i <= kNumPrintOfOutData; i++) {
+      std::cout << out_data.get()[i] << " ";
     }
     std::cout << std::endl;
   }
-  return RET_OK;
-}
-
-mindspore::session::LiteSession *Compile(mindspore::lite::Model *model) {
-  // Create and init context.
-  auto context = std::make_shared<mindspore::lite::Context>();
-  if (context == nullptr) {
-    std::cerr << "New context failed while." << std::endl;
-    return nullptr;
-  }
-  context->device_list_[0].provider_ = "Tutorial";
-  context->device_list_[0].provider_device_ = "Tutorial";
-  // Create the session.
-  auto *session = mindspore::session::LiteSession::CreateSession(context.get());
-  if (session == nullptr) {
-    std::cerr << "CreateSession failed while running." << std::endl;
-    return nullptr;
-  }
-
-  // Compile graph.
-  auto ret = session->CompileGraph(model);
-  if (ret != RET_OK) {
-    delete session;
-    std::cerr << "Compile failed while running." << std::endl;
-    return nullptr;
-  }
-
-  return session;
-}
-
-int CompileAndRun(int argc, const char **argv) {
-  if (argc < 2) {
-    std::cerr << "Model file must be provided.\n";
-    return RET_ERROR;
-  }
-  // Read model file.
-  auto model_path = RealPath(argv[1]);
-  if (model_path.empty()) {
-    std::cerr << "model path " << argv[1] << " is invalid.";
-    return RET_ERROR;
-  }
-  size_t size = 0;
-  char *model_buf = ReadFile(model_path.c_str(), &size);
-  if (model_buf == nullptr) {
-    std::cerr << "Read model file failed." << std::endl;
-    return RET_ERROR;
-  }
-  // Load the .ms model.
-  auto model = Model::Import(model_buf, size);
-  delete[](model_buf);
-  if (model == nullptr) {
-    std::cerr << "Import model file failed." << std::endl;
-    return RET_ERROR;
-  }
-  // Compile MindSpore Lite model.
-  auto session = Compile(model);
-  if (session == nullptr) {
-    delete model;
-    std::cerr << "Create session failed." << std::endl;
-    return RET_ERROR;
-  }
-  // Run inference.
-  auto ret = Run(session);
-  if (ret != RET_OK) {
-    delete model;
-    delete session;
-    std::cerr << "MindSpore Lite run failed." << std::endl;
-    return RET_ERROR;
-  }
-  // Delete model buffer.
-  delete model;
-  // Delete session buffer.
-  delete session;
-  return RET_OK;
+  return kSuccess;
 }
 }  // namespace lite
 }  // namespace mindspore
-int main(int argc, const char **argv) { return mindspore::lite::CompileAndRun(argc, argv); }
+int main(int argc, const char **argv) {
+  auto ret = mindspore::lite::CompileAndRun(argc, argv);
+  if (ret != mindspore::kSuccess) {
+    std::cerr << "run failed.\n";
+    return -1;
+  }
+  std::cout << "run success.\n";
+  return 0;
+}
diff --git a/mindspore/lite/tools/converter/export_model.cc b/mindspore/lite/tools/converter/export_model.cc
index 27e481d8588..732f18e9c84 100644
--- a/mindspore/lite/tools/converter/export_model.cc
+++ b/mindspore/lite/tools/converter/export_model.cc
@@ -191,7 +191,8 @@ STATUS ExportModel(const FuncGraphPtr &graph, const converter::Flags *flags) {
     return RET_ERROR;
   }
   (void)Manage(mirror_graph, true);
-  if (!RunOptimizerPass(mirror_graph, {"InferShapePass", "DeleteRedundantTranspose", "DecreaseTransposeAlgo"})) {
+  if (!RunOptimizerPass(mirror_graph,
+                        {"ToNHWCFormat", "InferShapePass", "DeleteRedundantTranspose", "DecreaseTransposeAlgo"})) {
     MS_LOG(ERROR) << "Run transpose opt pass failed.";
     return RET_ERROR;
   }
diff --git a/mindspore/lite/tools/converter/parser/unify_format.cc b/mindspore/lite/tools/converter/parser/unify_format.cc
index d89b4f35cd2..9fd24b25419 100644
--- a/mindspore/lite/tools/converter/parser/unify_format.cc
+++ b/mindspore/lite/tools/converter/parser/unify_format.cc
@@ -177,7 +177,8 @@ void UnifyFormatToNHWC::SetSensitiveOps() {
   sensitive_ops_.insert(sensitive_nchw_ops.begin(), sensitive_nchw_ops.end());
 }
 
-bool UnifyFormatToNHWC::DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ShapeVector &shape) {
+bool UnifyFormatToNHWC::DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ParameterPtr &input,
+                                                      const ShapeVector &shape) {
   MS_ASSERT(func_graph != nullptr);
   if (fmk_type_ == converter::kFmkTypeTf || fmk_type_ == converter::kFmkTypeTflite) {
     return false;
diff --git a/mindspore/lite/tools/converter/parser/unify_format.h b/mindspore/lite/tools/converter/parser/unify_format.h
index 37850ea6ced..0694a251a92 100644
--- a/mindspore/lite/tools/converter/parser/unify_format.h
+++ b/mindspore/lite/tools/converter/parser/unify_format.h
@@ -36,7 +36,8 @@ class UnifyFormatToNHWC : public opt::ToFormatBase {
   STATUS ConvertOnnxResizeForVariableShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
   STATUS GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) override;
   void SetSensitiveOps() override;
-  bool DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ShapeVector &shape) override;
+  bool DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ParameterPtr &input,
+                                     const ShapeVector &shape) override;
   bool DecideWhetherInferShapeForNewNode() override;
   STATUS DecideConvWeightSrcAndDstFormat(const CNodePtr &cnode, schema::Format *src_format,
                                          schema::Format *dst_format) override;
diff --git a/mindspore/lite/tools/optimizer/format/to_format_base.cc b/mindspore/lite/tools/optimizer/format/to_format_base.cc
index 5e46a31c170..d8bdbfbf321 100644
--- a/mindspore/lite/tools/optimizer/format/to_format_base.cc
+++ b/mindspore/lite/tools/optimizer/format/to_format_base.cc
@@ -178,6 +178,31 @@ STATUS ToFormatBase::InsertPostTransNode(const FuncGraphPtr &func_graph, const C
   return lite::RET_OK;
 }
 
+bool ToFormatBase::DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ParameterPtr &input,
+                                                 const ShapeVector &shape) {
+  MS_ASSERT(func_graph != nullptr && input != nullptr);
+  auto manager = func_graph->manager();
+  MS_ASSERT(manager != nullptr);
+  auto node_users = manager->node_users()[input];
+  for (auto &node_user : node_users) {
+    auto post_node = node_user.first;
+    if (!utils::isa<CNodePtr>(post_node)) {
+      continue;
+    }
+    auto post_cnode = post_node->cast<CNodePtr>();
+    auto prim = GetValueNode<PrimitivePtr>(post_cnode->input(0));
+    MS_ASSERT(prim != nullptr);
+    if (prim->GetAttr(ops::kFormat) != nullptr) {
+      auto node_format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
+      if (node_format == format_) {
+        MS_LOG(DEBUG) << "this graph input don't need to change.";
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
 STATUS ToFormatBase::HandleGraphInput(const FuncGraphPtr &func_graph) {
   MS_ASSERT(func_graph != nullptr);
   auto graph_input = func_graph->get_inputs();
@@ -191,7 +216,7 @@ STATUS ToFormatBase::HandleGraphInput(const FuncGraphPtr &func_graph) {
       MS_LOG(ERROR) << "fetch shape failed." << input->fullname_with_scope();
       return lite::RET_ERROR;
     }
-    if (shape.size() != kDimNumber || !DecideWhetherHandleGraphInput(func_graph, shape)) {
+    if (shape.size() != kDimNumber || !DecideWhetherHandleGraphInput(func_graph, input_param, shape)) {
       continue;
     }
     ShapeVector transfer_shape;
@@ -298,8 +323,6 @@ bool ToFormatBase::BasicProcess(const FuncGraphPtr &func_graph, bool main_graph)
 STATUS ToFormatBase::ConvWeightFormatTrans(const FuncGraphPtr &graph, std::set<AnfNodePtr> *has_visited) {
   MS_ASSERT(graph != nullptr && has_visited != nullptr);
   auto node_list = TopoSort(graph->get_return());
-  schema::Format src_format = schema::Format_NUM_OF_FORMAT;
-  schema::Format dst_format = schema::Format_NUM_OF_FORMAT;
   for (auto &node : node_list) {
     if (!utils::isa<CNodePtr>(node)) {
       continue;
@@ -335,6 +358,8 @@ STATUS ToFormatBase::ConvWeightFormatTrans(const FuncGraphPtr &graph, std::set
       continue;
     }
     has_visited->insert(node);
+    schema::Format src_format = schema::Format_NUM_OF_FORMAT;
+    schema::Format dst_format = schema::Format_NUM_OF_FORMAT;
     if (DecideConvWeightSrcAndDstFormat(cnode, &src_format, &dst_format) != lite::RET_OK) {
       MS_LOG(ERROR) << "weight's src format and dst format get failed.";
       return lite::RET_ERROR;
diff --git a/mindspore/lite/tools/optimizer/format/to_format_base.h b/mindspore/lite/tools/optimizer/format/to_format_base.h
index 6c6765c9f41..6f8ae6cadac 100644
--- a/mindspore/lite/tools/optimizer/format/to_format_base.h
+++ b/mindspore/lite/tools/optimizer/format/to_format_base.h
@@ -52,7 +52,8 @@ class ToFormatBase : public Pass {
  protected:
   virtual STATUS GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) = 0;
   virtual void SetSensitiveOps() { sensitive_ops_ = opt::GetNHWCOpMap(); }
-  virtual bool DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ShapeVector &shape) { return true; }
+  virtual bool DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ParameterPtr &input,
+                                             const ShapeVector &shape);
   virtual bool DecideWhetherInferShapeForNewNode() { return true; }
   virtual STATUS DecideConvWeightSrcAndDstFormat(const CNodePtr &cnode, schema::Format *src_format,
                                                  schema::Format *dst_format) = 0;
diff --git a/mindspore/lite/tools/optimizer/format/to_nchw_format.cc b/mindspore/lite/tools/optimizer/format/to_nchw_format.cc
index b7d853e5e13..150f4640d93 100644
--- a/mindspore/lite/tools/optimizer/format/to_nchw_format.cc
+++ b/mindspore/lite/tools/optimizer/format/to_nchw_format.cc
@@ -45,6 +45,20 @@ STATUS ToNCHWFormat::GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTyp
 STATUS ToNCHWFormat::DecideConvWeightSrcAndDstFormat(const CNodePtr &cnode, schema::Format *src_format,
                                                      schema::Format *dst_format) {
   MS_ASSERT(cnode != nullptr && src_format != nullptr && dst_format != nullptr);
+  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
+  MS_ASSERT(prim != nullptr);
+  if (prim->GetAttr(ops::kFormat) != nullptr) {
+    auto node_format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
+    if (node_format == mindspore::NCHW) {
+      MS_LOG(DEBUG) << "node's format has been nchw, no need to transfer, " << cnode->fullname_with_scope();
+      return lite::RET_OK;
+    }
+    if (node_format != mindspore::NHWC) {
+      MS_LOG(ERROR) << "node's format is invalid, which must be nhwc or nchw, now is " << node_format
+                    << ", node name is " << cnode->fullname_with_scope();
+      return lite::RET_ERROR;
+    }
+  }
   *src_format = schema::Format_KHWC;
   *dst_format = schema::Format_KCHW;
   return lite::RET_OK;
diff --git a/mindspore/lite/tools/optimizer/format/to_nhwc_format.cc b/mindspore/lite/tools/optimizer/format/to_nhwc_format.cc
index 33f786772db..a8122f11c25 100644
--- a/mindspore/lite/tools/optimizer/format/to_nhwc_format.cc
+++ b/mindspore/lite/tools/optimizer/format/to_nhwc_format.cc
@@ -45,6 +45,19 @@ STATUS ToNHWCFormat::GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTyp
 STATUS ToNHWCFormat::DecideConvWeightSrcAndDstFormat(const CNodePtr &cnode, schema::Format *src_format,
                                                      schema::Format *dst_format) {
   MS_ASSERT(cnode != nullptr && src_format != nullptr && dst_format != nullptr);
+  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
+  if (prim->GetAttr(ops::kFormat) != nullptr) {
+    auto node_format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
+    if (node_format == mindspore::NHWC) {
+      MS_LOG(DEBUG) << "node's format has been nhwc, no need to transfer, " << cnode->fullname_with_scope();
+      return lite::RET_OK;
+    }
+    if (node_format != mindspore::NCHW) {
+      MS_LOG(ERROR) << "node's format is invalid, which must be nhwc or nchw, now is " << node_format
+                    << ", node name is " << cnode->fullname_with_scope();
+      return lite::RET_ERROR;
+    }
+  }
   *src_format = schema::Format_KCHW;
   *dst_format = schema::Format_KHWC;
   return lite::RET_OK;
diff --git a/mindspore/lite/tools/optimizer/graph/infershape_pass.cc b/mindspore/lite/tools/optimizer/graph/infershape_pass.cc
index 9a94ed61740..a1ed3657d12 100644
--- a/mindspore/lite/tools/optimizer/graph/infershape_pass.cc
+++ b/mindspore/lite/tools/optimizer/graph/infershape_pass.cc
@@ -136,7 +136,7 @@ bool InferShapePass::JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph) {
       } else {
         all_op_can_infer = all_op_can_infer && JudgeAllOpsCanInfer(sub_func_graph);
       }
-      sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
+      sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(kInputIndexTwo));
      if (sub_func_graph == nullptr) {
        lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
        all_op_can_infer = false;