From a3b921891966334ac795bbcfee3bd72943f21c2e Mon Sep 17 00:00:00 2001
From: xuyongfei
Date: Mon, 23 Nov 2020 17:07:53 +0800
Subject: [PATCH] support mindspore::api::Model 910 inference, support 310
 model conversion in python env

---
 include/api/model.h                                |   6 +
 .../session/ascend_inference_session.cc            |  25 --
 .../session/ascend_inference_session.h             |   1 -
 mindspore/ccsrc/backend/session/infer_session.cc   |   3 +-
 mindspore/ccsrc/backend/session/session_basic.cc   |  57 +++
 mindspore/ccsrc/backend/session/session_basic.h    |   5 +-
 mindspore/ccsrc/cxx_api/CMakeLists.txt             |  29 +-
 mindspore/ccsrc/cxx_api/model/acl/acl_model.cc     |   1 -
 .../ccsrc/cxx_api/model/acl/model_converter.cc     | 129 +++++-
 .../ccsrc/cxx_api/model/acl/model_converter.h      |   3 +
 mindspore/ccsrc/cxx_api/model/model.cc             |   8 +
 .../model/model_converter_utils/multi_process.cc   | 207 +++++++++
 .../model/model_converter_utils/multi_process.h    |  68 +++
 .../model/model_converter_utils/shared_memory.cc   |  69 +++
 .../model/model_converter_utils/shared_memory.h    |  41 ++
 mindspore/ccsrc/cxx_api/model/model_impl.h         |   8 +-
 mindspore/ccsrc/cxx_api/model/ms/ms_model.cc       | 418 ++++++++++++++++++
 mindspore/ccsrc/cxx_api/model/ms/ms_model.h        |  85 ++++
 18 files changed, 1123 insertions(+), 40 deletions(-)
 create mode 100644 mindspore/ccsrc/cxx_api/model/model_converter_utils/multi_process.cc
 create mode 100644 mindspore/ccsrc/cxx_api/model/model_converter_utils/multi_process.h
 create mode 100644 mindspore/ccsrc/cxx_api/model/model_converter_utils/shared_memory.cc
 create mode 100644 mindspore/ccsrc/cxx_api/model/model_converter_utils/shared_memory.h
 create mode 100644 mindspore/ccsrc/cxx_api/model/ms/ms_model.cc
 create mode 100644 mindspore/ccsrc/cxx_api/model/ms/ms_model.h

diff --git a/include/api/model.h b/include/api/model.h
index e14b778b491..6378d45a9fb 100644
--- a/include/api/model.h
+++ b/include/api/model.h
@@ -50,9 +50,15 @@ class MS_API Model {
   Status GetInputsInfo(std::vector<Tensor> *tensor_list) const;
   Status GetOutputsInfo(std::vector<Tensor> *tensor_list) const;
 
+  static bool CheckModelSupport(const std::string &device_type, ModelType model_type);
+
  private:
   std::shared_ptr<ModelImpl> impl_;
 };
+
+extern MS_API const char *kDeviceTypeAscendCL;
+extern MS_API const char *kDeviceTypeAscendMS;
+
 }  // namespace api
 }  // namespace mindspore
 #endif  // MINDSPORE_INCLUDE_API_MODEL_H
diff --git a/mindspore/ccsrc/backend/session/ascend_inference_session.cc b/mindspore/ccsrc/backend/session/ascend_inference_session.cc
index 423825bc028..c9f1dca6d9a 100644
--- a/mindspore/ccsrc/backend/session/ascend_inference_session.cc
+++ b/mindspore/ccsrc/backend/session/ascend_inference_session.cc
@@ -213,30 +213,5 @@ std::string AscendInferenceSession::InputsInfo(const std::vector<ParameterPtr> &
   return graph + " " + actual;
 }
 
-void AscendInferenceSession::GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs) const {
-  MS_LOG(INFO) << "Start get model inputs, graph id : " << graph_id;
-  auto kernel_graph = GetGraph(graph_id);
-  MS_EXCEPTION_IF_NULL(kernel_graph);
-  auto kernel_graph_inputs = kernel_graph->inputs();
-  vector<ParameterPtr> paras;
-  // find parameters of graph inputs
-  for (size_t i = 0; i < kernel_graph_inputs.size(); ++i) {
-    if (!kernel_graph_inputs[i]->isa<Parameter>()) {
-      MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter.";
-      continue;
-    }
-    auto parameter = kernel_graph_inputs[i]->cast<ParameterPtr>();
-    if (!AnfAlgo::IsParameterWeight(parameter)) {
-      vector<int64_t> input_shape;
-      auto parameter_shape = AnfAlgo::GetOutputDeviceShape(parameter, 0);
-      (void)std::transform(parameter_shape.begin(), parameter_shape.end(), std::back_inserter(input_shape),
-                           [](const size_t dim) { return SizeToLong(dim); });
-      auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(parameter);
-      auto data_type = kernel_build_info->GetOutputDeviceType(0);
-      auto ms_tensor = std::make_shared<tensor::Tensor>(data_type, input_shape);
-      inputs->push_back(ms_tensor);
-    }
-  }
-}
 }  // namespace session
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/session/ascend_inference_session.h b/mindspore/ccsrc/backend/session/ascend_inference_session.h
index e75243a4434..671d2e09c73 100644
--- a/mindspore/ccsrc/backend/session/ascend_inference_session.h
+++ b/mindspore/ccsrc/backend/session/ascend_inference_session.h
@@ -44,7 +44,6 @@ class AscendInferenceSession : public AscendSession {
   template <typename T>
   std::string PrintInputShape(std::vector<T> shape) const;
   std::string InputsInfo(const std::vector<ParameterPtr> &paras, const std::vector<tensor::TensorPtr> &inputs) const;
-  void GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs) const override;
 
  protected:
   GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) override;
diff --git a/mindspore/ccsrc/backend/session/infer_session.cc b/mindspore/ccsrc/backend/session/infer_session.cc
index 798bdc0a58b..d82a9f30087 100644
--- a/mindspore/ccsrc/backend/session/infer_session.cc
+++ b/mindspore/ccsrc/backend/session/infer_session.cc
@@ -370,7 +370,8 @@ Status MSInferSession::CheckModelInputs(uint32_t graph_id, const std::vector<te
 
 Status MSInferSession::GetModelInputsInfo(uint32_t model_id, std::vector<inference::InferTensor> *tensor_list) const {
   vector<tensor::TensorPtr> inputs;
-  session_impl_->GetModelInputsInfo(model_id, &inputs);
+  vector<string> input_names;
+  session_impl_->GetModelInputsInfo(model_id, &inputs, &input_names);
   if (inputs.size() == 0) {
     MS_LOG(ERROR) << "The model inputs is NULL";
     return FAILED;
diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc
index 164601380fc..31114356caf 100644
--- a/mindspore/ccsrc/backend/session/session_basic.cc
+++ b/mindspore/ccsrc/backend/session/session_basic.cc
@@ -34,6 +34,8 @@
 #include "ir/func_graph_cloner.h"
 #include "utils/utils.h"
 #include "debug/anf_ir_dump.h"
+#include "mindspore/core/base/base_ref_utils.h"
+
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
 #include "ps/worker.h"
 #include "ps/util.h"
@@ -1089,6 +1091,61 @@ void SessionBasic::CreateOutputTensors(const GraphId &graph_id, const std::vecto
   }
 }
 
+void SessionBasic::GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs,
+                                      std::vector<std::string> *inputs_name) const {
+  MS_LOG(INFO) << "Start get model inputs, graph id : " << graph_id;
+  auto kernel_graph = GetGraph(graph_id);
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  MS_EXCEPTION_IF_NULL(inputs);
+  MS_EXCEPTION_IF_NULL(inputs_name);
+  auto kernel_graph_inputs = kernel_graph->inputs();
+  vector<ParameterPtr> paras;
+  // find parameters of graph inputs
+  for (size_t i = 0; i < kernel_graph_inputs.size(); ++i) {
+    if (!kernel_graph_inputs[i]->isa<Parameter>()) {
+      MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter.";
+      continue;
+    }
+    auto parameter = kernel_graph_inputs[i]->cast<ParameterPtr>();
+    if (!AnfAlgo::IsParameterWeight(parameter)) {
+      vector<int64_t> input_shape;
+      auto parameter_shape = AnfAlgo::GetOutputDeviceShape(parameter, 0);
+      (void)std::transform(parameter_shape.begin(), parameter_shape.end(), std::back_inserter(input_shape),
+                           [](const size_t dim) { return SizeToLong(dim); });
+      auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(parameter);
+      auto data_type = kernel_build_info->GetOutputDeviceType(0);
+      auto ms_tensor = std::make_shared<tensor::Tensor>(data_type, input_shape);
+      inputs->push_back(ms_tensor);
+      inputs_name->push_back(parameter->name());
+    }
+  }
+}
+
+void SessionBasic::GetModelOutputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *outputs,
+                                       std::vector<std::string> *output_names) const {
+  std::vector<tensor::TensorPtr> inputs;
+  std::vector<std::string> input_names;
+  GetModelInputsInfo(graph_id, &inputs, &input_names);
+
+  auto kernel_graph = GetGraph(graph_id);
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  MS_EXCEPTION_IF_NULL(outputs);
+  MS_EXCEPTION_IF_NULL(output_names);
+
+  VectorRef vector_outputs;
+  std::map<tensor::TensorPtr, session::KernelWithIndex> tensor_to_node;
+  auto anf_outputs = kernel_graph->outputs();
+  for (auto &item : anf_outputs) {
+    MS_EXCEPTION_IF_NULL(item);
+    MS_LOG(INFO) << "Create node output[" << item->DebugString() << "]";
+    vector_outputs.emplace_back(CreateNodeOutputTensors(item, kernel_graph, inputs, &tensor_to_node));
+  }
+  *outputs = TransformVectorRefToMultiTensor(vector_outputs);
+  for (size_t i = 0; i < outputs->size(); i++) {
+    output_names->push_back("output" + std::to_string(i));
+  }
+}
+
 void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) {
   MS_EXCEPTION_IF_NULL(callback);
   summary_callback_ = callback;
diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h
index 94947e143a4..8ce61db5898 100644
--- a/mindspore/ccsrc/backend/session/session_basic.h
+++ b/mindspore/ccsrc/backend/session/session_basic.h
@@ -102,7 +102,10 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
                                  std::string *error_msg) const {
     return true;
   }
-  virtual void GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs) const {}
+  void GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs,
+                          std::vector<std::string> *inputs_name) const;
+  void GetModelOutputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *outputs,
+                           std::vector<std::string> *outputs_name) const;
   std::vector<tensor::TensorPtr> GetInputNeedLockTensors(const GraphId &graph_id,
                                                          const std::vector<tensor::TensorPtr> &inputs);
   // Get graph by graph id, if not exist return null ptr
diff --git a/mindspore/ccsrc/cxx_api/CMakeLists.txt b/mindspore/ccsrc/cxx_api/CMakeLists.txt
index 23d6c4935b6..2f4954c89bc 100644
--- a/mindspore/ccsrc/cxx_api/CMakeLists.txt
+++ b/mindspore/ccsrc/cxx_api/CMakeLists.txt
@@ -6,22 +6,25 @@ set(LOAD_ONNX_SRC
 file(GLOB_RECURSE API_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR} "ops/*.cc")
 
 if (ENABLE_ACL)
-    file(GLOB_RECURSE API_ACL_SRC ${CMAKE_CURRENT_SOURCE_DIR} "model/acl/*.cc")
+    file(GLOB_RECURSE API_ACL_SRC ${CMAKE_CURRENT_SOURCE_DIR} "model/acl/*.cc" "model/model_converter_utils/*.cc")
+elseif (ENABLE_D)
+    file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "model/ms/*.cc")
 endif ()
 
 set(MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc
-    ${CMAKE_CURRENT_SOURCE_DIR}/cell.cc
-    ${CMAKE_CURRENT_SOURCE_DIR}/serialization.cc
-    ${CMAKE_CURRENT_SOURCE_DIR}/model/model.cc
-    ${API_ACL_SRC}
-    ${API_OPS_SRC}
-    ${LOAD_ONNX_SRC})
+              ${CMAKE_CURRENT_SOURCE_DIR}/cell.cc
+              ${CMAKE_CURRENT_SOURCE_DIR}/serialization.cc
+              ${CMAKE_CURRENT_SOURCE_DIR}/model/model.cc
+              ${API_MS_INFER_SRC}
+              ${API_ACL_SRC}
+              ${API_OPS_SRC}
+              ${LOAD_ONNX_SRC})
 
 add_library(mindspore_shared_lib SHARED ${MSLIB_SRC})
 set_target_properties(mindspore_shared_lib PROPERTIES OUTPUT_NAME mindspore PUBLIC_HEADER "${API_INCLUDE}")
 
 target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
-        -Wl,--whole-archive mindspore -Wl,--no-whole-archive mindspore_gvar mindspore::protobuf)
+        -Wl,--whole-archive mindspore -Wl,--no-whole-archive proto_input mindspore_gvar mindspore::protobuf)
 
 if (ENABLE_CPU)
     target_link_libraries(mindspore_shared_lib PRIVATE mindspore::dnnl mindspore::mkldnn)
@@ -58,5 +61,13 @@ if (ENABLE_ACL)
     find_library(acl_runtime libruntime.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
    find_library(ge_compiler libge_compiler.so ${ATC_DIR}/lib64 ${ATLAS_ATC_DIR}/lib64)
 
     target_link_libraries(mindspore_shared_lib PRIVATE ${acl} ${acl_retr} ${acl_cblas} ${acl_dvpp} ${acl_runtime}
-            ${ge_compiler} mindspore::jpeg_turbo)
+                          ${ge_compiler} mindspore::jpeg_turbo)
 endif ()
+
+
+# Before build inference
+if (ENABLE_D)
+    find_library(adump_server libadump_server.a ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
+    target_link_libraries(mindspore_shared_lib PRIVATE ${adump_server})
+endif ()
+
diff --git a/mindspore/ccsrc/cxx_api/model/acl/acl_model.cc b/mindspore/ccsrc/cxx_api/model/acl/acl_model.cc
index 10ecfc5b05e..754a5808d55 100644
--- a/mindspore/ccsrc/cxx_api/model/acl/acl_model.cc
+++ b/mindspore/ccsrc/cxx_api/model/acl/acl_model.cc
@@ -91,7 +91,6 @@ Status AclModel::InitEnv() {
     MS_LOG(ERROR) << "DVPP init resource failed";
     return FAILED;
   }
-  ModelConverter::RegAllOp();
 
   MS_LOG(INFO) << "Init acl success, device id " << device_id_;
   init_flag_ = true;
diff --git a/mindspore/ccsrc/cxx_api/model/acl/model_converter.cc b/mindspore/ccsrc/cxx_api/model/acl/model_converter.cc
index bc79310aa21..3fde01a4d6e 100644
--- a/mindspore/ccsrc/cxx_api/model/acl/model_converter.cc
+++ b/mindspore/ccsrc/cxx_api/model/acl/model_converter.cc
@@ -24,6 +24,7 @@
 #include "backend/kernel_compiler/oplib/oplib.h"
 #include "graph/model.h"
 
+#include "cxx_api/model/model_converter_utils/multi_process.h"
 
 namespace py = pybind11;
 
@@ -238,6 +239,131 @@ Buffer ModelConverter::ReadFile(const std::string &file) {
 }
 
 Buffer ModelConverter::LoadMindIR(const Buffer &model_data) {
+  if (Py_IsInitialized() == 0) {
+    MS_LOG_INFO << "Call LoadMindIRInner directly";
+    return LoadMindIRInner(model_data);
+  }
+  MultiProcess multi_process;
+  Buffer buffer_ret;
+  auto parent_process = [&model_data, &buffer_ret](MultiProcess *multi_process) -> Status {
+    MS_EXCEPTION_IF_NULL(multi_process);
+    // send original model to child
+    auto status = multi_process->SendMsg(model_data.Data(), model_data.DataSize());
+    if (!status.IsSuccess()) {
+      MS_LOG_ERROR << "Send original model to child process failed";
+      return FAILED;
+    }
+    // receive converted model result from child
+    CreateBufferCall call = [&buffer_ret](size_t msg_len) -> uint8_t * {
+      buffer_ret.ResizeData(msg_len);
+      return reinterpret_cast<uint8_t *>(buffer_ret.MutableData());
+    };
+    status = multi_process->ReceiveMsg(call);
+    if (!status.IsSuccess()) {
+      MS_LOG_ERROR << "Receive result model from child process failed";
+      return FAILED;
+    }
+    return SUCCESS;
+  };
+  auto child_process = [this](MultiProcess *multi_process) -> Status {
+    MS_EXCEPTION_IF_NULL(multi_process);
+    // receive original model from parent
+    Buffer model;
+    CreateBufferCall call = [&model](size_t msg_len) -> uint8_t * {
+      model.ResizeData(msg_len);
+      return reinterpret_cast<uint8_t *>(model.MutableData());
+    };
+    auto status = multi_process->ReceiveMsg(call);
+    if (!status.IsSuccess()) {
+      MS_LOG_ERROR << "Receive original model from parent process failed";
+      return FAILED;
+    }
+    Buffer model_result = LoadMindIRInner(model);
+    if (model_result.DataSize() == 0) {
+      MS_LOG_ERROR << "Convert model from MindIR to OM failed";
+      return FAILED;
+    }
+    // send result model to parent
+    status = multi_process->SendMsg(model_result.Data(), model_result.DataSize());
+    if (!status.IsSuccess()) {
+      MS_LOG_ERROR << "Send result model to parent process failed";
+      return FAILED;
+    }
+    return SUCCESS;
+  };
+  auto status = multi_process.MainProcess(parent_process, child_process);
+  if (!status.IsSuccess()) {
+    MS_LOG_ERROR << "Convert MindIR model to OM model failed";
+  } else {
+    MS_LOG_INFO << "Convert MindIR model to OM model success";
+  }
+  return buffer_ret;
+}
+
+Buffer ModelConverter::LoadAscendIR(const Buffer &model_data) {
+  if (Py_IsInitialized() == 0) {
+    MS_LOG_INFO << "Call LoadAscendIRInner directly";
+    return LoadAscendIRInner(model_data);
+  }
+  MultiProcess multi_process;
+  Buffer buffer_ret;
+  auto parent_process = [&model_data, &buffer_ret](MultiProcess *multi_process) -> Status {
+    MS_EXCEPTION_IF_NULL(multi_process);
+    // send original model to child
+    auto status = multi_process->SendMsg(model_data.Data(), model_data.DataSize());
+    if (!status.IsSuccess()) {
+      MS_LOG_ERROR << "Send original model to child process failed";
+      return FAILED;
+    }
+    // receive converted model result from child
+    CreateBufferCall call = [&buffer_ret](size_t msg_len) -> uint8_t * {
+      buffer_ret.ResizeData(msg_len);
+      return reinterpret_cast<uint8_t *>(buffer_ret.MutableData());
+    };
+    status = multi_process->ReceiveMsg(call);
+    if (!status.IsSuccess()) {
+      MS_LOG_ERROR << "Receive result model from child process failed";
+      return FAILED;
+    }
+    return SUCCESS;
+  };
+  auto child_process = [this](MultiProcess *multi_process) -> Status {
+    MS_EXCEPTION_IF_NULL(multi_process);
+    // receive original model from parent
+    Buffer model;
+    CreateBufferCall call = [&model](size_t msg_len) -> uint8_t * {
+      model.ResizeData(msg_len);
+      return reinterpret_cast<uint8_t *>(model.MutableData());
+    };
+    auto status = multi_process->ReceiveMsg(call);
+    if (!status.IsSuccess()) {
+      MS_LOG_ERROR << "Receive original model from parent process failed";
+      return FAILED;
+    }
+    Buffer model_result = LoadAscendIRInner(model);
+    if (model_result.DataSize() == 0) {
+      MS_LOG_ERROR << "Convert model from AIR to OM failed";
+      return FAILED;
+    }
+    // send result model to parent
+    status = multi_process->SendMsg(model_result.Data(), model_result.DataSize());
+    if (!status.IsSuccess()) {
+      MS_LOG_ERROR << "Send result model to parent process failed";
+      return FAILED;
+    }
+    return SUCCESS;
+  };
+  auto status = multi_process.MainProcess(parent_process, child_process);
+  if (!status.IsSuccess()) {
+    MS_LOG_ERROR << "Convert AIR model to OM model failed";
+  } else {
+    MS_LOG_INFO << "Convert AIR model to OM model success";
+  }
+  return buffer_ret;
+}
+
+Buffer ModelConverter::LoadMindIRInner(const Buffer &model_data) {
+  RegAllOp();
   auto func_graph = ConvertMindIrToFuncGraph(model_data);
   if (func_graph == nullptr) {
     MS_LOG(ERROR) << "Convert MindIR to FuncGraph failed.";
@@ -259,7 +385,8 @@ Buffer ModelConverter::LoadMindIR(const Buffer &model_data) {
   return om_data;
 }
 
-Buffer ModelConverter::LoadAscendIR(const Buffer &model_data) {
+Buffer ModelConverter::LoadAscendIRInner(const Buffer &model_data) {
+  RegAllOp();
   ge::Model load_model = ge::Model("loadmodel", "version2");
   ge::Status ret =
     ge::Model::Load(reinterpret_cast<const uint8_t *>(model_data.Data()), model_data.DataSize(), load_model);
diff --git a/mindspore/ccsrc/cxx_api/model/acl/model_converter.h b/mindspore/ccsrc/cxx_api/model/acl/model_converter.h
index 6189ad56def..21d34ed3366 100644
--- a/mindspore/ccsrc/cxx_api/model/acl/model_converter.h
+++ b/mindspore/ccsrc/cxx_api/model/acl/model_converter.h
@@ -45,6 +45,9 @@ class ModelConverter {
   transform::DfGraphPtr ConvertFuncGraphToAIR(const FuncGraphPtr &anf_graph);
   Buffer BuildAirModel(const transform::DfGraphPtr &graph, const std::map<std::string, std::string> &acl_options);
   AclModelOptions *options_;
+
+  Buffer LoadMindIRInner(const Buffer &model_data);
+  Buffer LoadAscendIRInner(const Buffer &model_data);
 };
 }  // namespace mindspore::api
diff --git a/mindspore/ccsrc/cxx_api/model/model.cc b/mindspore/ccsrc/cxx_api/model/model.cc
index 27a27f3797b..0d3e5aee62d 100644
--- a/mindspore/ccsrc/cxx_api/model/model.cc
+++ b/mindspore/ccsrc/cxx_api/model/model.cc
@@ -18,6 +18,9 @@
 #include "utils/utils.h"
 
 namespace mindspore::api {
+const char *kDeviceTypeAscendCL = "AscendCL";
+const char *kDeviceTypeAscendMS = "AscendMS";
+
 Status Model::LoadModel(const Buffer &model_data, ModelType type, const std::map<std::string, std::string> &options) {
   MS_EXCEPTION_IF_NULL(impl_);
   return impl_->LoadModel(model_data, type, options);
@@ -95,4 +98,9 @@ Model::Model(NetWork network, const std::string &device_type, uint32_t device_id
 }
 
 Model::~Model() {}
+
+bool Model::CheckModelSupport(const std::string &device_type, ModelType model_type) {
+  return ModelFactory::Instance().CheckModelSupport(device_type, model_type);
+}
+
 }  // namespace mindspore::api
diff --git a/mindspore/ccsrc/cxx_api/model/model_converter_utils/multi_process.cc b/mindspore/ccsrc/cxx_api/model/model_converter_utils/multi_process.cc
new file mode 100644
index 00000000000..9e8866d8011
--- /dev/null
+++ b/mindspore/ccsrc/cxx_api/model/model_converter_utils/multi_process.cc
@@ -0,0 +1,207 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "cxx_api/model/model_converter_utils/multi_process.h"
+#include <unistd.h>
+#include <sys/types.h>
+#include <sys/wait.h>
+#include <algorithm>
+#include <thread>
+#include "mindspore/core/utils/log_adapter.h"
+#include "cxx_api/model/model_converter_utils/shared_memory.h"
+
+namespace mindspore {
+namespace api {
+
+namespace {
+uint64_t kSharedMemorySize = 100ull << 20;  // 100 MB
+}
+
+MultiProcess::MultiProcess() = default;
+
+MultiProcess::~MultiProcess() = default;
+
+Status MultiProcess::MainProcess(ProcessFuncCall parent_process, ProcessFuncCall child_process) {
+  MS_EXCEPTION_IF_NULL(parent_process);
+  MS_EXCEPTION_IF_NULL(child_process);
+  Status ret;
+  memory_size_ = kSharedMemorySize;  // 100 MB
+  SharedMemory shared_memory;
+  ret = shared_memory.Create(memory_size_);
+  if (!ret.IsSuccess()) {
+    MS_LOG_ERROR << "Create shared memory failed";
+    return ret;
+  }
+  pid_t pid = fork();
+  if (pid < 0) {
+    shared_memory.Destroy();
+    MS_LOG_ERROR << "Fork process to convert model failed";
+    return FAILED;
+  }
+  ret = shared_memory.Attach();
+  if (!ret.IsSuccess()) {
+    MS_LOG_ERROR << "Process attach shared memory failed, pid " << pid;
+    return ret;
+  }
+  shmat_addr_ = shared_memory.GetSharedMemoryAddr();
+  if (shmat_addr_ == nullptr) {
+    MS_LOG_ERROR << "Get shared memory failed";
+    return ret;
+  }
+  shmat_data_addr_ = shmat_addr_ + sizeof(MessageFlag) * 2;
+  shmat_data_max_size_ = memory_size_ - (shmat_data_addr_ - shmat_addr_);
+
+  MS_LOG_INFO << "Shm addr " << (uint64_t)shmat_addr_;
+  if (pid == 0) {
+    ChildProcess(child_process);
+    shared_memory.Detach();
+    MS_LOG_INFO << "Model converter: child process exit";
+    exit(0);
+  } else {  // parent process
+    ret = ParentProcess(parent_process);
+    shared_memory.Detach();
+    int status;
+    wait(&status);
+    shared_memory.Destroy();
+  }
+  return ret;
+}
+
+Status MultiProcess::ParentProcess(ProcessFuncCall parent_process) {
+  auto parent_msg = reinterpret_cast<MessageFlag *>(shmat_addr_);
+  auto child_msg = reinterpret_cast<MessageFlag *>(shmat_addr_ + sizeof(MessageFlag));
+  send_msg_ = parent_msg;
+  receive_msg_ = child_msg;
+  std::thread heartbeat_thread(MultiProcess::HeartbeatThreadFunc, this);
+  Status ret;
+  try {
+    ret = parent_process(this);
+    if (!ret.IsSuccess()) {
+      MS_LOG_ERROR << "Parent process callback failed";
+    }
+  } catch (const std::runtime_error &ex) {
+    MS_LOG_ERROR << "Catch parent process runtime error: " << ex.what();
+    ret = FAILED;
+  }
+  stopped_ = true;
+  send_msg_->stop = true;
+  heartbeat_thread.join();
+  return ret;
+}
+
+void MultiProcess::ChildProcess(ProcessFuncCall child_process) {
+  auto parent_msg = reinterpret_cast<MessageFlag *>(shmat_addr_);
+  auto child_msg = reinterpret_cast<MessageFlag *>(shmat_addr_ + sizeof(MessageFlag));
+  send_msg_ = child_msg;
+  receive_msg_ = parent_msg;
+  std::thread heartbeat_thread(MultiProcess::HeartbeatThreadFunc, this);
+  try {
+    auto ret = child_process(this);
+    if (!ret.IsSuccess()) {
+      MS_LOG_ERROR << "Child process callback failed";
+    }
+  } catch (const std::runtime_error &ex) {
+    MS_LOG_ERROR << "Catch child process runtime error: " << ex.what();
+  }
+  stopped_ = true;
+  send_msg_->stop = true;
+  heartbeat_thread.join();
+}
+
+Status MultiProcess::SendMsg(const void *buffer, uint64_t msg_len) {
+  MS_LOG_INFO << "Start to send message to peer process, msg len " << msg_len;
+  send_msg_->msg_total_len = msg_len;
+  uint64_t cur_offset = 0;
+  while (msg_len > cur_offset) {
+    uint64_t sub_msg_len = std::min(msg_len - cur_offset, shmat_data_max_size_);
+
+    memcpy_s(shmat_data_addr_, shmat_data_max_size_, static_cast<const uint8_t *>(buffer) + cur_offset, sub_msg_len);
+    cur_offset += sub_msg_len;
+
+    send_msg_->msg_len = sub_msg_len;
+    send_msg_->read_finish_flag = false;
+    send_msg_->read_ready_flag = true;
+    MS_LOG_INFO << "Send start " << cur_offset << ", msg len " << sub_msg_len << ", total len " << msg_len;
+    while (!send_msg_->read_finish_flag && !peer_stopped_) {
+      usleep(1000);  // 1ms
+    }
+    if (peer_stopped_) {
+      if (!send_msg_->read_finish_flag) {
+        return FAILED;
+      }
+      break;
+    }
+    MS_LOG_INFO << "Send end " << cur_offset << ", msg len " << sub_msg_len << ", total len " << msg_len;
+  }
+  MS_LOG_INFO << "Finished sending message to peer process, msg len " << msg_len;
+  return SUCCESS;
+}
+
+Status MultiProcess::ReceiveMsg(CreateBufferCall create_buffer_call) {
+  uint64_t cur_offset = 0;
+  uint8_t *msg_buffer = nullptr;
+  uint64_t msg_len = 0;
+  do {
+    MS_LOG_INFO << "Receive start from " << cur_offset;
+    while (!receive_msg_->read_ready_flag && !peer_stopped_) {
+      usleep(1000);  // 1ms
+    }
+    if (peer_stopped_) {
+      return FAILED;
+    }
+    if (msg_buffer == nullptr) {
+      msg_len = receive_msg_->msg_total_len;
+      msg_buffer = create_buffer_call(msg_len);
+    }
+    memcpy_s(msg_buffer + cur_offset, msg_len - cur_offset, shmat_data_addr_, receive_msg_->msg_len);
+    cur_offset += receive_msg_->msg_len;
+    receive_msg_->read_ready_flag = false;
+    receive_msg_->read_finish_flag = true;
+    MS_LOG_INFO << "Receive end, current length " << cur_offset << ", total length " << msg_len << std::endl;
+  } while (msg_len > cur_offset);
+  return SUCCESS;
+}
+
+void MultiProcess::HeartbeatThreadFunc(MultiProcess *multi_process) { multi_process->HeartbeatThreadFuncInner(); }
+
+void MultiProcess::HeartbeatThreadFuncInner() {
+  uint64_t last_beat_cnt = 0;
+  uint64_t repeat_cnt = 0;
+  while (!stopped_) {
+    if (receive_msg_->stop) {
+      peer_stopped_ = true;
+      MS_LOG_WARNING << "Peer stopped";
+      break;
+    }
+    uint64_t heartbeat_gap = receive_msg_->heartbeat - last_beat_cnt;
+    if (heartbeat_gap > 0 && heartbeat_gap < 1024) {
+      last_beat_cnt = receive_msg_->heartbeat;
+      repeat_cnt = 0;
+    } else {
+      repeat_cnt++;
+      if (repeat_cnt > 30) {  // 30*100ms = 3s no reply
+        peer_stopped_ = true;
+        MS_LOG_WARNING << "Peer stopped";
+        break;
+      }
+    }
+    send_msg_->heartbeat += 1;
+    usleep(100000);  // sleep 100 ms
+  }
+}
+
+}  // namespace api
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/cxx_api/model/model_converter_utils/multi_process.h b/mindspore/ccsrc/cxx_api/model/model_converter_utils/multi_process.h
new file mode 100644
index 00000000000..ec384912fef
--- /dev/null
+++ b/mindspore/ccsrc/cxx_api/model/model_converter_utils/multi_process.h
@@ -0,0 +1,68 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_CXXAPI_MULTI_PROCESS_H
+#define MINDSPORE_CCSRC_CXXAPI_MULTI_PROCESS_H
+#include <cstdint>
+#include <functional>
+#include "include/api/status.h"
+
+namespace mindspore {
+namespace api {
+
+struct MessageFlag {
+  uint64_t heartbeat = 0;
+  uint64_t stop = false;
+  uint64_t msg_len = 0;
+  uint64_t msg_total_len = 0;
+  uint64_t read_ready_flag = false;
+  uint64_t read_finish_flag = false;
+};
+
+class MultiProcess;
+using ProcessFuncCall = std::function<Status(MultiProcess *multi_process)>;
+using CreateBufferCall = std::function<uint8_t *(size_t msg_len)>;
+
+class MultiProcess {
+ public:
+  MultiProcess();
+  ~MultiProcess();
+
+  Status MainProcess(ProcessFuncCall parent_process, ProcessFuncCall child_process);
+  Status SendMsg(const void *buffer, uint64_t msg_len);
+  Status ReceiveMsg(CreateBufferCall create_buffer_call);
+
+ private:
+  uint8_t *shmat_addr_ = nullptr;
+  uint8_t *shmat_data_addr_ = nullptr;
+  uint64_t shmat_data_max_size_ = 0;
+  uint64_t memory_size_ = 0;
+
+  bool peer_stopped_ = false;
+  bool stopped_ = false;
+  MessageFlag *send_msg_ = nullptr;
+  MessageFlag *receive_msg_ = nullptr;
+
+  static void HeartbeatThreadFunc(MultiProcess *multi_process);
+  void HeartbeatThreadFuncInner();
+  Status ParentProcess(ProcessFuncCall parent_process);
+  void ChildProcess(ProcessFuncCall child_process);
+};
+
+}  // namespace api
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_CXXAPI_MULTI_PROCESS_H
diff --git a/mindspore/ccsrc/cxx_api/model/model_converter_utils/shared_memory.cc b/mindspore/ccsrc/cxx_api/model/model_converter_utils/shared_memory.cc
new file mode 100644
index 00000000000..46446f65a0b
--- /dev/null
+++ b/mindspore/ccsrc/cxx_api/model/model_converter_utils/shared_memory.cc
@@ -0,0 +1,69 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "cxx_api/model/model_converter_utils/shared_memory.h"
+#include <sys/shm.h>
+#include <sys/stat.h>
+#include <string>
+#include "mindspore/core/utils/log_adapter.h"
+
+namespace mindspore {
+namespace api {
+
+Status SharedMemory::Create(uint64_t memory_size) {
+  auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP;
+  shm_id_ = shmget(IPC_PRIVATE, memory_size, IPC_CREAT | IPC_EXCL | access_mode);
+  if (shm_id_ == -1) {
+    MS_LOG_ERROR << "Shared memory creation failed. Errno " + std::to_string(errno);
+    return FAILED;
+  }
+  MS_LOG_INFO << "shmget success, shm id " << shm_id_;
+  return SUCCESS;
+}
+
+Status SharedMemory::Attach() {
+  void *shmat_addr = shmat(shm_id_, nullptr, 0);
+  if (shmat_addr == reinterpret_cast<void *>(-1)) {
+    MS_LOG_ERROR << "Shared memory attach failed. Errno " + std::to_string(errno);
+    return FAILED;
+  }
+  shmat_addr_ = reinterpret_cast<uint8_t *>(shmat_addr);
+  return SUCCESS;
+}
+
+void SharedMemory::Detach() {
+  if (shmat_addr_) {
+    auto err = shmdt(shmat_addr_);
+    if (err == -1) {
+      MS_LOG_ERROR << "Shared memory detach failed. Errno " + std::to_string(errno);
+      return;
+    }
+  }
+  shmat_addr_ = nullptr;
+}
+
+void SharedMemory::Destroy() {
+  // Remove the shared memory and never mind about the return code.
+  auto err = shmctl(shm_id_, IPC_RMID, nullptr);
+  if (err == -1) {
+    std::string errMsg = "Unable to remove shared memory with id " + std::to_string(shm_id_);
+    errMsg += ". Errno: " + std::to_string(errno);
+    errMsg += "\nPlease remove it manually using the ipcrm -m command";
+    MS_LOG_ERROR << errMsg;
+  }
+}
+
+}  // namespace api
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/cxx_api/model/model_converter_utils/shared_memory.h b/mindspore/ccsrc/cxx_api/model/model_converter_utils/shared_memory.h
new file mode 100644
index 00000000000..b79b3ff6a70
--- /dev/null
+++ b/mindspore/ccsrc/cxx_api/model/model_converter_utils/shared_memory.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_CXXAPI_SHARED_MEMORY_H
+#define MINDSPORE_CCSRC_CXXAPI_SHARED_MEMORY_H
+#include <cstdint>
+#include "include/api/status.h"
+
+namespace mindspore {
+namespace api {
+
+class SharedMemory {
+ public:
+  Status Create(uint64_t memory_size);
+  Status Attach();
+  void Detach();
+  void Destroy();
+  uint8_t *GetSharedMemoryAddr() { return shmat_addr_; }
+
+ private:
+  int shm_id_ = -1;
+  uint8_t *shmat_addr_ = nullptr;
+};
+
+}  // namespace api
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_CXXAPI_SHARED_MEMORY_H
diff --git a/mindspore/ccsrc/cxx_api/model/model_impl.h b/mindspore/ccsrc/cxx_api/model/model_impl.h
index 3ef26a6c3e9..d513a9a4c81 100644
--- a/mindspore/ccsrc/cxx_api/model/model_impl.h
+++ b/mindspore/ccsrc/cxx_api/model/model_impl.h
@@ -70,6 +70,12 @@ class ModelFactory {
     return nullptr;
   }
 
+  bool CheckModelSupport(const std::string &device_type, ModelType /*model_type*/) {
+    return std::any_of(
+      model_creators_.begin(), model_creators_.end(),
+      [&device_type](const std::pair<std::string, ModelCreator> &item) { return item.first == device_type; });
+  }
+
  private:
   ModelFactory() = default;
   ~ModelFactory() = default;
@@ -86,7 +92,7 @@ class ModelRegistrar {
 
 #define API_REG_MODEL(DEVICE_NAME, MODEL_CLASS)                                  \
   static const ModelRegistrar g_api_model_registrar__##DEVICE_NAME##_##_reg(     \
-    #DEVICE_NAME, [](uint32_t device_id) { return std::make_shared<MODEL_CLASS>(device_id); });
+    kDeviceType##DEVICE_NAME, [](uint32_t device_id) { return std::make_shared<MODEL_CLASS>(device_id); });
 
 }  // namespace mindspore::api
diff --git a/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc b/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc
new file mode 100644
index 00000000000..e3484153804
--- /dev/null
+++ b/mindspore/ccsrc/cxx_api/model/ms/ms_model.cc
@@ -0,0 +1,418 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "cxx_api/model/ms/ms_model.h"
+#include <memory>
+#include <algorithm>
+#include <fstream>
+
+#include "utils/load_onnx/anf_converter.h"
+#include "backend/session/session_basic.h"
+#include "backend/session/session_factory.h"
+#include "backend/session/executor_manager.h"
+#include "base/base_ref_utils.h"
+#include "backend/kernel_compiler/oplib/oplib.h"
+#include "utils/context/context_extends.h"
+#include "runtime/device/kernel_runtime_manager.h"
+
+#include "pybind11/pybind11.h"
+#include "pybind11/embed.h"
+
+#ifdef ENABLE_D
+#include "utils/ms_context.h"
+#endif
+
+using std::string;
+using std::vector;
+
+namespace py = pybind11;
+namespace mindspore {
+namespace api {
+
+MsModel::MsModel(uint32_t device_id) : device_id_(device_id) {}
+MsModel::~MsModel() = default;
+
+TypeId TransInferDataType2TypeId(DataType data_type) {
+  const std::map<DataType, TypeId> type2id_map{
+      {api::kMsUnknown, TypeId::kNumberTypeBegin},   {api::kMsBool, TypeId::kNumberTypeBool},
+      {api::kMsInt8, TypeId::kNumberTypeInt8},       {api::kMsUint8, TypeId::kNumberTypeUInt8},
+      {api::kMsInt16, TypeId::kNumberTypeInt16},     {api::kMsUint16, TypeId::kNumberTypeUInt16},
+      {api::kMsInt32, TypeId::kNumberTypeInt32},     {api::kMsUint32, TypeId::kNumberTypeUInt32},
+      {api::kMsInt64, TypeId::kNumberTypeInt64},     {api::kMsUint64, TypeId::kNumberTypeUInt64},
+      {api::kMsFloat16, TypeId::kNumberTypeFloat16}, {api::kMsFloat32, TypeId::kNumberTypeFloat32},
+      {api::kMsFloat64, TypeId::kNumberTypeFloat64},
+  };
+  auto it = type2id_map.find(data_type);
+  if (it == type2id_map.end()) {
+    MS_LOG_WARNING << "Unsupported MSI data type " << data_type;
+    return TypeId::kNumberTypeBegin;
+  } else {
+    return it->second;
+  }
+}
+
+DataType TransTypeId2InferDataType(TypeId type_id) {
+  const std::map<TypeId, DataType> id2type_map{
+      {TypeId::kNumberTypeBegin, api::kMsUnknown},   {TypeId::kNumberTypeBool, api::kMsBool},
+      {TypeId::kNumberTypeFloat64, api::kMsFloat64}, {TypeId::kNumberTypeInt8, api::kMsInt8},
+      {TypeId::kNumberTypeUInt8, api::kMsUint8},     {TypeId::kNumberTypeInt16, api::kMsInt16},
+      {TypeId::kNumberTypeUInt16, api::kMsUint16},   {TypeId::kNumberTypeInt32, api::kMsInt32},
+      {TypeId::kNumberTypeUInt32, api::kMsUint32},   {TypeId::kNumberTypeInt64, api::kMsInt64},
+      {TypeId::kNumberTypeUInt64, api::kMsUint64},   {TypeId::kNumberTypeFloat16, api::kMsFloat16},
+      {TypeId::kNumberTypeFloat32, api::kMsFloat32},
+  };
+  auto it = id2type_map.find(type_id);
+  if (it == id2type_map.end()) {
+    MS_LOG_WARNING << "Unsupported data id " << type_id;
+    return api::kMsUnknown;
+  } else {
+    return it->second;
+  }
+}
+
+Buffer MsModel::ReadFile(const std::string &file) {
+  if (file.empty()) {
+    MS_LOG(ERROR) << "file name is empty";
+    return Buffer();
+  }
+  std::ifstream ifs(file);
+  if (!ifs.good()) {
+    MS_LOG(ERROR) << "file: " << file << " does not exist";
+    return Buffer();
+  }
+
+  if (!ifs.is_open()) {
+    MS_LOG(ERROR) << "file: " << file << " open failed";
+    return Buffer();
+  }
+
+  ifs.seekg(0, std::ios::end);
+  size_t size = ifs.tellg();
+  Buffer buffer;
+  buffer.ResizeData(size);
+  ifs.seekg(0, std::ios::beg);
+  ifs.read(static_cast<char *>(buffer.MutableData()), size);
+  ifs.close();
+
+  return buffer;
+}
+
+Status MsModel::LoadModel(const Buffer &model_data, ModelType type, const std::map<std::string, std::string> &options) {
+  auto status = InitEnv({});
+  if (status != SUCCESS) {
+    MS_LOG(ERROR) << "Init env failed";
+    return FAILED;
+  }
+  std::shared_ptr<FuncGraph> anf_graph;
+  try {
+    anf_graph =
+      lite::AnfConverter::RunAnfConverter(static_cast<const char *>(model_data.Data()), model_data.DataSize());
+  } catch (std::exception &e) {
+    MS_LOG(ERROR) << "Inference LoadModel failed";
+    return FAILED;
+  }
+  Status ret = CompileGraph(anf_graph);
+  if (ret != SUCCESS) {
+    MS_LOG(ERROR) << "Compile graph model failed";
+    return FAILED;
+  }
+  session_impl_->GetModelInputsInfo(graph_id_, &inputs_, &input_names_);
+  session_impl_->GetModelOutputsInfo(graph_id_, &outputs_, &output_names_);
+  if (inputs_.empty() || inputs_.size() != input_names_.size()) {
+    MS_LOG_ERROR << "Get model inputs info failed";
+    return FAILED;
+  }
+  if (outputs_.empty() || outputs_.size() != output_names_.size()) {
+    MS_LOG_ERROR << "Get model outputs info failed";
+    return FAILED;
+  }
+  MS_LOG(INFO) << "Load model success";
+
+#ifdef ENABLE_D
+  // set d context
+  rtError_t rt_ret = rtCtxGetCurrent(&context_);
+  if (rt_ret != RT_ERROR_NONE || context_ == nullptr) {
+    MS_LOG(ERROR) << "the ascend device context is null";
+    return FAILED;
+  }
+#endif
+  load_flag_ = true;
+  return SUCCESS;
+}
+
+Status MsModel::LoadModel(const std::string &file_name, ModelType type,
+                          const std::map<std::string, std::string> &options) {
+  auto graphBuf = ReadFile(file_name);
+  if (graphBuf.DataSize() == 0) {
+    MS_LOG(ERROR) << "Read model file failed, file name is " << file_name.c_str();
+    return FAILED;
+  }
+  auto status = LoadModel(graphBuf, type, options);
+  if (status != SUCCESS) {
+    MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
+Status MsModel::UnloadModel() {
+  if (!load_flag_) {
+    MS_LOG_ERROR << "Model has not been loaded";
+    return FAILED;
+  }
+  FinalizeEnv();
+  load_flag_ = false;
+  return SUCCESS;
+}
+
+Status MsModel::Train(const DataSet &, std::map<std::string, Buffer> *) {
+  MS_LOG(ERROR) << "Unsupported feature.";
+  return FAILED;
+}
+
+Status MsModel::Eval(const DataSet &, std::map<std::string, Buffer> *) {
+  MS_LOG(ERROR) << "Unsupported feature.";
+  return FAILED;
+}
+
+Status MsModel::Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs) {
+  MS_EXCEPTION_IF_NULL(outputs);
+  if (!load_flag_) {
+    MS_LOG(ERROR) << "No model is loaded, predict failed.";
+    return FAILED;
+  }
+  if (inputs.size() != inputs_.size()) {
+    MS_LOG(ERROR) << "inputs count does not match, required count " << inputs_.size() << ", given count "
+                  << inputs.size();
+    return INVALID_INPUTS;
+  }
+  std::vector<Buffer> request;
+  std::vector<Buffer> reply;
+  for (size_t i = 0; i < inputs_.size(); ++i) {
+    const auto &input_name = input_names_[i];
+    auto iter = inputs.find(input_name);
+    if (iter == inputs.end()) {
+      MS_LOG(ERROR) << "Model missing input " << input_name;
+      return INVALID_INPUTS;
+    }
+
+    if (iter->second.DataSize() != inputs_[i]->Size()) {
+      MS_LOG(ERROR) << "input " << i << " data size does not match, required size " << inputs_[i]->Size()
+                    << ", given size " << iter->second.DataSize();
+      return INVALID_INPUTS;
+    }
+    request.push_back(iter->second);
+  }
+  if (ExecuteModel(request, &reply) != SUCCESS) {
+    MS_LOG(ERROR) << "Execute Model Failed";
+    return FAILED;
+  }
+  if (outputs_.size() != reply.size()) {
+    MS_LOG(ERROR) << "Predict output size " << reply.size() << " not match output size got from model info "
+                  << outputs_.size();
+    return FAILED;
+  }
+  outputs->clear();
+  for (size_t i = 0; i < reply.size(); i++) {
+    outputs->emplace(output_names_[i], reply[i]);
+  }
+  return SUCCESS;
+}
+
+Status MsModel::ExecuteModel(const std::vector<Buffer> &request, std::vector<Buffer> *reply) {
+  MS_EXCEPTION_IF_NULL(reply);
+#ifdef ENABLE_D
+  if (context_ == nullptr) {
+    MS_LOG(ERROR) << "rtCtx is nullptr";
nullptr"; + return FAILED; + } + rtError_t rt_ret = rtCtxSetCurrent(context_); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(ERROR) << "set Ascend rtCtx failed"; + return FAILED; + } +#endif + vector inputs; + for (size_t i = 0; i < request.size(); i++) { + auto &item = request[i]; + auto input = inputs_[i]; + if (input->Size() != item.DataSize()) { + MS_LOG(ERROR) << "Predict input " << i << " data size " << item.DataSize() << " not match model input data size " + << input->Size(); + return FAILED; + } + auto ret = memcpy_s(input->data_c(), input->Size(), item.Data(), item.DataSize()); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "Tensor copy failed"; + return FAILED; + } + inputs.push_back(input); + } + vector outputs = RunGraph(inputs); + if (outputs.empty()) { + MS_LOG(ERROR) << "Execute Model Failed"; + return FAILED; + } + reply->clear(); + std::transform(outputs.begin(), outputs.end(), std::back_inserter(*reply), + [](const tensor::TensorPtr &tensor) { return Buffer(tensor->data_c(), tensor->Size()); }); + return SUCCESS; +} + +Status MsModel::FinalizeEnv() { + MS_LOG_INFO << "Start finalize env"; + py::gil_scoped_acquire acquire; + session::ExecutorManager::Instance().Clear(); + device::KernelRuntimeManager::Instance().ClearRuntimeResource(); + auto ms_context = MsContext::GetInstance(); + if (ms_context == nullptr) { + MS_LOG(ERROR) << "Get Context failed!"; + return FAILED; + } + if (!context::CloseTsd(ms_context)) { + MS_LOG(ERROR) << "Inference CloseTsd failed!"; + return FAILED; + } + MS_LOG_INFO << "End finalize env"; + return SUCCESS; +} + +std::shared_ptr MsModel::LoadModel(const char *model_buf, size_t size, const std::string &device) { + MS_EXCEPTION_IF_NULL(model_buf); + try { + auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size); + return anf_graph; + } catch (std::exception &e) { + MS_LOG(ERROR) << "Inference LoadModel failed: " << e.what(); + return nullptr; + } +} + +void MsModel::RegAllOp() { + static std::mutex init_mutex; + static bool Initialized = false; + + std::lock_guard lock(init_mutex); + if (Initialized) { + return; + } + Initialized = true; + auto ms_context_instance = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context_instance); + ms_context_instance->set_param(MS_CTX_EXECUTION_MODE, kGraphMode); + try { + std::shared_ptr guard; + if (Py_IsInitialized() == 0) { + guard = std::make_shared(); + } + py::module c_expression = py::module::import("mindspore._c_expression"); + size_t ops_info_long = c_expression.attr("OpInfoLoaderPy")().attr("get_all_ops_info")().cast(); + auto all_ops_info = reinterpret_cast *>(ops_info_long); + for (auto op_info : *all_ops_info) { + kernel::OpLib::RegOpInfo(std::shared_ptr(op_info)); + } + all_ops_info->clear(); + delete all_ops_info; + } catch (const std::runtime_error &ex) { + MS_LOG_EXCEPTION << ex.what(); + } +} + +Status MsModel::CompileGraph(std::shared_ptr funcGraphPtr) { + MS_ASSERT(session_impl_ != nullptr); + try { + graph_id_ = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr)); + py::gil_scoped_release gil_release; + return SUCCESS; + } catch (std::exception &e) { + MS_LOG(ERROR) << "Inference CompileGraph failed: " << e.what(); + return FAILED; + } +} + +std::vector MsModel::RunGraph(const std::vector &inputs) { + try { + VectorRef outputs; + session_impl_->RunGraph(graph_id_, inputs, &outputs); + return TransformVectorRefToMultiTensor(outputs); + } catch (std::exception &e) { + MS_LOG(ERROR) << "Inference Rungraph failed: " << e.what(); + return std::vector(); + } +} + +Status 
+  RegAllOp();
+  auto ms_context = MsContext::GetInstance();
+  if (ms_context == nullptr) {
+    MS_LOG(ERROR) << "Get Context failed!";
+    return FAILED;
+  }
+  ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
+  ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_);
+  ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kAscendDevice);
+  if (!context::OpenTsd(ms_context)) {
+    MS_LOG(ERROR) << "Session init OpenTsd failed!";
+    return FAILED;
+  }
+  session_impl_ = session::SessionFactory::Get().Create(kDavinciInferenceDevice);
+  if (session_impl_ == nullptr) {
+    MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << kDavinciInferenceDevice
+                  << " is available.";
+    return FAILED;
+  }
+  session_impl_->Init(device_id_);
+  return SUCCESS;
+}
+
+Status MsModel::CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const {
+  MS_ASSERT(session_impl_ != nullptr);
+  std::string error_msg;
+  if (!session_impl_->CheckModelInputs(graph_id, inputs, &error_msg)) {
+    return Status(INVALID_INPUTS, error_msg);
+  }
+  return SUCCESS;
+}
+
+Status MsModel::GetInputsInfo(std::vector<Tensor> *tensor_list) const {
+  MS_EXCEPTION_IF_NULL(tensor_list);
+  tensor_list->clear();
+  for (size_t i = 0; i < inputs_.size(); i++) {
+    auto &tensor = inputs_[i];
+    Tensor infer_tensor;
+    infer_tensor.SetName(input_names_[i]);
+    infer_tensor.SetDataType(TransTypeId2InferDataType(tensor->data_type()));
+    infer_tensor.SetShape(tensor->shape());
+    tensor_list->push_back(infer_tensor);
+  }
+  return SUCCESS;
+}
+
+Status MsModel::GetOutputsInfo(std::vector<Tensor> *tensor_list) const {
+  MS_EXCEPTION_IF_NULL(tensor_list);
+  tensor_list->clear();
+  for (size_t i = 0; i < outputs_.size(); i++) {
+    auto &tensor = outputs_[i];
+    Tensor infer_tensor;
+    infer_tensor.SetName(output_names_[i]);
+    infer_tensor.SetDataType(TransTypeId2InferDataType(tensor->data_type()));
+    infer_tensor.SetShape(tensor->shape());
+    tensor_list->push_back(infer_tensor);
+  }
+  return SUCCESS;
+}
+
+}  // namespace api
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/cxx_api/model/ms/ms_model.h b/mindspore/ccsrc/cxx_api/model/ms/ms_model.h
new file mode 100644
index 00000000000..de1dc4de85c
--- /dev/null
+++ b/mindspore/ccsrc/cxx_api/model/ms/ms_model.h
@@ -0,0 +1,85 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_SESSION_SESSION_H
+#define MINDSPORE_CCSRC_SESSION_SESSION_H
+
+#include <memory>
+#include <string>
+#include <vector>
+#include <map>
+#include <unordered_map>
+#include <utility>
+
+#include "backend/session/session_basic.h"
+#include "ir/anf.h"
+#include "include/api/status.h"
+#include "cxx_api/model/model_impl.h"
+
+#ifdef ENABLE_D
+#include "runtime/context.h"
+#endif
+
+namespace mindspore {
+namespace api {
+class MsModel : public ModelImpl {
+ public:
+  explicit MsModel(uint32_t device_id);
+  ~MsModel();
+
+  Status LoadModel(const Buffer &model_data, ModelType type,
+                   const std::map<std::string, std::string> &options) override;
+  Status LoadModel(const std::string &file_name, ModelType type,
+                   const std::map<std::string, std::string> &options) override;
+  Status UnloadModel() override;
+
+  Status Train(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override;
+  Status Eval(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override;
+  Status Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs) override;
+
+  Status GetInputsInfo(std::vector<Tensor> *tensor_list) const override;
+  Status GetOutputsInfo(std::vector<Tensor> *tensor_list) const override;
+
+  Status InitEnv(const std::unordered_map<std::string, std::string> &other_options);
+  Status FinalizeEnv();
+
+ private:
+  std::shared_ptr<session::SessionBasic> session_impl_ = nullptr;
+  uint32_t graph_id_;
+  std::string device_type_;
+  int32_t device_id_ = 0;
+#ifdef ENABLE_D
+  rtContext_t context_ = nullptr;
+#endif
+  std::vector<tensor::TensorPtr> inputs_;
+  std::vector<tensor::TensorPtr> outputs_;
+  std::vector<std::string> input_names_;
+  std::vector<std::string> output_names_;
+  bool load_flag_ = false;
+
+  std::shared_ptr<FuncGraph> LoadModel(const char *model_buf, size_t size, const std::string &device);
+  Buffer ReadFile(const std::string &file);
+  static void RegAllOp();
+  Status CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr);
+  Status CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const;
+  std::vector<tensor::TensorPtr> RunGraph(const std::vector<tensor::TensorPtr> &inputs);
+  Status ExecuteModel(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs);
+};
+
+API_REG_MODEL(AscendMS, MsModel);
+
+}  // namespace api
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_SESSION_SESSION_H
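
Note (not part of the patch): a minimal caller-side sketch of the public API surface this change adds. The Model(device_type, device_id) constructor form, the ModelType enumerator spelling (kMindIR), and the file-based LoadModel overload on the public Model class are assumptions inferred from the headers in this diff, not confirmed by it.

    // Hypothetical usage sketch; constructor form and ModelType::kMindIR are assumed.
    #include <map>
    #include <string>
    #include "include/api/model.h"

    int main() {
      using namespace mindspore::api;
      // New in this patch: probe which backends were compiled in before choosing one.
      const char *device = Model::CheckModelSupport(kDeviceTypeAscendMS, ModelType::kMindIR)
                             ? kDeviceTypeAscendMS   // 910 in-process inference (MsModel)
                             : kDeviceTypeAscendCL;  // 310 ACL path (AclModel)
      Model model(device, /*device_id=*/0);
      if (model.LoadModel("net.mindir", ModelType::kMindIR, {}) != SUCCESS) {
        return 1;
      }
      std::map<std::string, Buffer> inputs;   // keyed by the names reported via GetInputsInfo()
      std::map<std::string, Buffer> outputs;
      return model.Predict(inputs, &outputs) == SUCCESS ? 0 : 1;
    }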
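Note (not part of the patch): the fork/shared-memory handshake that LoadMindIR and LoadAscendIR build on, condensed from model_converter.cc above with most error handling elided. This only restates the call pattern already present in this diff: the parent streams the source model through the shared-memory window, the child converts it with its own Python/GE state, and streams the OM result back.

    MultiProcess mp;
    Buffer result;
    // Parent: send the source model to the child, then collect the converted model.
    auto parent = [&](MultiProcess *p) -> Status {
      auto status = p->SendMsg(model_data.Data(), model_data.DataSize());
      if (!status.IsSuccess()) {
        return FAILED;
      }
      return p->ReceiveMsg([&result](size_t len) -> uint8_t * {
        result.ResizeData(len);
        return reinterpret_cast<uint8_t *>(result.MutableData());
      });
    };
    // Child: receive the model, convert it in a clean process, send the result back.
    auto child = [&](MultiProcess *p) -> Status {
      Buffer model;
      auto status = p->ReceiveMsg([&model](size_t len) -> uint8_t * {
        model.ResizeData(len);
        return reinterpret_cast<uint8_t *>(model.MutableData());
      });
      if (!status.IsSuccess()) {
        return FAILED;
      }
      Buffer om = LoadMindIRInner(model);  // conversion runs in the forked child
      return p->SendMsg(om.Data(), om.DataSize());
    };
    Status st = mp.MainProcess(parent, child);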