support mindspore::api::Model 910 inference, support 310 model conversion in python env
This commit is contained in:
parent 452cb0dd4e
commit a3b9218919
|
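For orientation, the new public surface in include/api/model.h can be exercised roughly as below. This is a minimal sketch, not code from this commit: the Model(device_type, device_id) constructor overload and the ModelType::kMindIR value are assumptions outside this diff, and error handling is omitted.

#include "include/api/model.h"
// Sketch only: CheckModelSupport and the device-type constants are added by this commit;
// the constructor overload and ModelType value used below are assumed.
using namespace mindspore::api;
bool RunOn910(const Buffer &mindir_model) {
  if (!Model::CheckModelSupport(kDeviceTypeAscendMS, ModelType::kMindIR)) {
    return false;  // the 910 (MS session) backend was not compiled into this build
  }
  Model model(kDeviceTypeAscendMS, 0);  // assumed ctor: device type + device id
  if (model.LoadModel(mindir_model, ModelType::kMindIR, {}) != SUCCESS) {  // forwards to MsModel
    return false;
  }
  std::vector<Tensor> inputs_info;
  model.GetInputsInfo(&inputs_info);  // names, dtypes and shapes of the graph inputs
  return true;
}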
@@ -50,9 +50,15 @@ class MS_API Model {
|
|||
Status GetInputsInfo(std::vector<Tensor> *tensor_list) const;
|
||||
Status GetOutputsInfo(std::vector<Tensor> *tensor_list) const;
|
||||
|
||||
static bool CheckModelSupport(const std::string& device_type, ModelType model_type);
|
||||
|
||||
private:
|
||||
std::shared_ptr<ModelImpl> impl_;
|
||||
};
|
||||
|
||||
extern MS_API const char* kDeviceTypeAscendCL;
|
||||
extern MS_API const char* kDeviceTypeAscendMS;
|
||||
|
||||
} // namespace api
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_MODEL_H
|
||||
|
|
|
@@ -213,30 +213,5 @@ std::string AscendInferenceSession::InputsInfo(const std::vector<ParameterPtr> &
|
|||
return graph + " " + actual;
|
||||
}
|
||||
|
||||
void AscendInferenceSession::GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs) const {
|
||||
MS_LOG(INFO) << "Start get model inputs, graph id : " << graph_id;
|
||||
auto kernel_graph = GetGraph(graph_id);
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto kernel_graph_inputs = kernel_graph->inputs();
|
||||
vector<ParameterPtr> paras;
|
||||
// find parameters of graph inputs
|
||||
for (size_t i = 0; i < kernel_graph_inputs.size(); ++i) {
|
||||
if (!kernel_graph_inputs[i]->isa<Parameter>()) {
|
||||
MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter.";
|
||||
continue;
|
||||
}
|
||||
auto parameter = kernel_graph_inputs[i]->cast<ParameterPtr>();
|
||||
if (!AnfAlgo::IsParameterWeight(parameter)) {
|
||||
vector<int64_t> input_shape;
|
||||
auto parameter_shape = AnfAlgo::GetOutputDeviceShape(parameter, 0);
|
||||
(void)std::transform(parameter_shape.begin(), parameter_shape.end(), std::back_inserter(input_shape),
|
||||
[](const size_t dim) { return SizeToLong(dim); });
|
||||
auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(parameter);
|
||||
auto data_type = kernel_build_info->GetOutputDeviceType(0);
|
||||
auto ms_tensor = std::make_shared<tensor::Tensor>(data_type, input_shape);
|
||||
inputs->push_back(ms_tensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace session
|
||||
} // namespace mindspore
|
||||
|
|
|
@@ -44,7 +44,6 @@ class AscendInferenceSession : public AscendSession {
|
|||
template <typename T>
|
||||
std::string PrintInputShape(std::vector<T> shape) const;
|
||||
std::string InputsInfo(const std::vector<ParameterPtr> ¶s, const std::vector<tensor::TensorPtr> &inputs) const;
|
||||
void GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs) const override;
|
||||
|
||||
protected:
|
||||
GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) override;
|
||||
|
|
|
@@ -370,7 +370,8 @@ Status MSInferSession::CheckModelInputs(uint32_t graph_id, const std::vector<ten
|
|||
|
||||
Status MSInferSession::GetModelInputsInfo(uint32_t model_id, std::vector<inference::InferTensor> *tensor_list) const {
|
||||
vector<tensor::TensorPtr> inputs;
|
||||
session_impl_->GetModelInputsInfo(model_id, &inputs);
|
||||
vector<std::string> input_names;
|
||||
session_impl_->GetModelInputsInfo(model_id, &inputs, &input_names);
|
||||
if (inputs.size() == 0) {
|
||||
MS_LOG(ERROR) << "The model inputs is NULL";
|
||||
return FAILED;
|
||||
|
|
|
@@ -34,6 +34,8 @@
|
|||
#include "ir/func_graph_cloner.h"
|
||||
#include "utils/utils.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
#include "mindspore/core/base/base_ref_utils.h"
|
||||
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
#include "ps/worker.h"
|
||||
#include "ps/util.h"
|
||||
|
@@ -1089,6 +1091,61 @@ void SessionBasic::CreateOutputTensors(const GraphId &graph_id, const std::vecto
|
|||
}
|
||||
}
|
||||
|
||||
void SessionBasic::GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs,
|
||||
std::vector<std::string> *inputs_name) const {
|
||||
MS_LOG(INFO) << "Start get model inputs, graph id : " << graph_id;
|
||||
auto kernel_graph = GetGraph(graph_id);
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
MS_EXCEPTION_IF_NULL(inputs);
|
||||
MS_EXCEPTION_IF_NULL(inputs_name);
|
||||
auto kernel_graph_inputs = kernel_graph->inputs();
|
||||
vector<ParameterPtr> paras;
|
||||
// find parameters of graph inputs
|
||||
for (size_t i = 0; i < kernel_graph_inputs.size(); ++i) {
|
||||
if (!kernel_graph_inputs[i]->isa<Parameter>()) {
|
||||
MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter.";
|
||||
continue;
|
||||
}
|
||||
auto parameter = kernel_graph_inputs[i]->cast<ParameterPtr>();
|
||||
if (!AnfAlgo::IsParameterWeight(parameter)) {
|
||||
vector<int64_t> input_shape;
|
||||
auto parameter_shape = AnfAlgo::GetOutputDeviceShape(parameter, 0);
|
||||
(void)std::transform(parameter_shape.begin(), parameter_shape.end(), std::back_inserter(input_shape),
|
||||
[](const size_t dim) { return SizeToLong(dim); });
|
||||
auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(parameter);
|
||||
auto data_type = kernel_build_info->GetOutputDeviceType(0);
|
||||
auto ms_tensor = std::make_shared<tensor::Tensor>(data_type, input_shape);
|
||||
inputs->push_back(ms_tensor);
|
||||
inputs_name->push_back(parameter->name());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SessionBasic::GetModelOutputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *outputs,
|
||||
std::vector<std::string> *output_names) const {
|
||||
std::vector<tensor::TensorPtr> inputs;
|
||||
std::vector<std::string> input_names;
|
||||
GetModelInputsInfo(graph_id, &inputs, &input_names);
|
||||
|
||||
auto kernel_graph = GetGraph(graph_id);
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
MS_EXCEPTION_IF_NULL(outputs);
|
||||
MS_EXCEPTION_IF_NULL(output_names);
|
||||
|
||||
VectorRef vector_outputs;
|
||||
std::map<tensor::TensorPtr, session::KernelWithIndex> tensor_to_node;
|
||||
auto anf_outputs = kernel_graph->outputs();
|
||||
for (auto &item : anf_outputs) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
MS_LOG(INFO) << "Create node output[" << item->DebugString() << "]";
|
||||
vector_outputs.emplace_back(CreateNodeOutputTensors(item, kernel_graph, inputs, &tensor_to_node));
|
||||
}
|
||||
*outputs = TransformVectorRefToMultiTensor(vector_outputs);
|
||||
for (size_t i = 0; i < outputs->size(); i++) {
|
||||
output_names->push_back("output" + std::to_string(i));
|
||||
}
|
||||
}
|
||||
|
||||
void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) {
|
||||
MS_EXCEPTION_IF_NULL(callback);
|
||||
summary_callback_ = callback;
|
||||
|
|
|
@@ -102,7 +102,10 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
std::string *error_msg) const {
|
||||
return true;
|
||||
}
|
||||
virtual void GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs) const {}
|
||||
void GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs,
|
||||
std::vector<std::string> *inputs_name) const;
|
||||
void GetModelOutputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *outputs,
|
||||
std::vector<std::string> *outputs_name) const;
|
||||
std::vector<tensor::TensorPtr> GetInputNeedLockTensors(const GraphId &graph_id,
|
||||
const std::vector<tensor::TensorPtr> &inputs);
|
||||
// Get graph by graph id, if not exist return null ptr
|
||||
|
|
|
@@ -6,22 +6,25 @@ set(LOAD_ONNX_SRC
|
|||
file(GLOB_RECURSE API_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR} "ops/*.cc")
|
||||
|
||||
if (ENABLE_ACL)
|
||||
file(GLOB_RECURSE API_ACL_SRC ${CMAKE_CURRENT_SOURCE_DIR} "model/acl/*.cc")
|
||||
file(GLOB_RECURSE API_ACL_SRC ${CMAKE_CURRENT_SOURCE_DIR} "model/acl/*.cc" "model/model_converter_utils/*.cc")
|
||||
elseif (ENABLE_D)
|
||||
file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "model/ms/*.cc")
|
||||
endif ()
|
||||
|
||||
set(MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cell.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/serialization.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/model/model.cc
|
||||
${API_ACL_SRC}
|
||||
${API_OPS_SRC}
|
||||
${LOAD_ONNX_SRC})
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cell.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/serialization.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/model/model.cc
|
||||
${API_MS_INFER_SRC}
|
||||
${API_ACL_SRC}
|
||||
${API_OPS_SRC}
|
||||
${LOAD_ONNX_SRC})
|
||||
|
||||
add_library(mindspore_shared_lib SHARED ${MSLIB_SRC})
|
||||
set_target_properties(mindspore_shared_lib PROPERTIES OUTPUT_NAME mindspore PUBLIC_HEADER "${API_INCLUDE}")
|
||||
|
||||
target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
|
||||
-Wl,--whole-archive mindspore -Wl,--no-whole-archive mindspore_gvar mindspore::protobuf)
|
||||
-Wl,--whole-archive mindspore -Wl,--no-whole-archive proto_input mindspore_gvar mindspore::protobuf)
|
||||
|
||||
if (ENABLE_CPU)
|
||||
target_link_libraries(mindspore_shared_lib PRIVATE mindspore::dnnl mindspore::mkldnn)
|
||||
|
@@ -58,5 +61,13 @@ if (ENABLE_ACL)
|
|||
find_library(acl_runtime libruntime.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
|
||||
find_library(ge_compiler libge_compiler.so ${ATC_DIR}/lib64 ${ATLAS_ATC_DIR}/lib64)
|
||||
target_link_libraries(mindspore_shared_lib PRIVATE ${acl} ${acl_retr} ${acl_cblas} ${acl_dvpp} ${acl_runtime}
|
||||
${ge_compiler} mindspore::jpeg_turbo)
|
||||
${ge_compiler} mindspore::jpeg_turbo)
|
||||
endif ()
|
||||
|
||||
|
||||
# Before build inference
|
||||
if (ENABLE_D)
|
||||
find_library(adump_server libadump_server.a ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||
target_link_libraries(mindspore_shared_lib PRIVATE ${adump_server})
|
||||
endif ()
|
||||
|
||||
|
|
|
@@ -91,7 +91,6 @@ Status AclModel::InitEnv() {
|
|||
MS_LOG(ERROR) << "DVPP init resource failed";
|
||||
return FAILED;
|
||||
}
|
||||
ModelConverter::RegAllOp();
|
||||
|
||||
MS_LOG(INFO) << "Init acl success, device id " << device_id_;
|
||||
init_flag_ = true;
|
||||
|
|
|
@@ -24,6 +24,7 @@
|
|||
#include "backend/kernel_compiler/oplib/oplib.h"
|
||||
|
||||
#include "graph/model.h"
|
||||
#include "cxx_api/model/model_converter_utils/multi_process.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
|
@@ -238,6 +239,131 @@ Buffer ModelConverter::ReadFile(const std::string &file) {
|
|||
}
|
||||
|
||||
Buffer ModelConverter::LoadMindIR(const Buffer &model_data) {
|
||||
if (Py_IsInitialized() == 0) {
|
||||
MS_LOG_INFO << "Call LoadMindIRInner directly";
|
||||
return LoadMindIRInner(model_data);
|
||||
}
|
||||
MultiProcess multi_process;
|
||||
Buffer buffer_ret;
|
||||
auto parent_process = [&model_data, &buffer_ret](MultiProcess *multi_process) -> Status {
|
||||
MS_EXCEPTION_IF_NULL(multi_process);
|
||||
// send original model to child
|
||||
auto status = multi_process->SendMsg(model_data.Data(), model_data.DataSize());
|
||||
if (!status.IsSuccess()) {
|
||||
MS_LOG_ERROR << "Send original model to child process failed";
|
||||
return FAILED;
|
||||
}
|
||||
// receive convert model result from child
|
||||
CreateBufferCall call = [&buffer_ret](size_t msg_len) -> uint8_t * {
|
||||
buffer_ret.ResizeData(msg_len);
|
||||
return reinterpret_cast<uint8_t *>(buffer_ret.MutableData());
|
||||
};
|
||||
status = multi_process->ReceiveMsg(call);
|
||||
if (!status.IsSuccess()) {
|
||||
MS_LOG_ERROR << "Receive result model from child process failed";
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
};
|
||||
auto child_process = [this](MultiProcess *multi_process) -> Status {
|
||||
MS_EXCEPTION_IF_NULL(multi_process);
|
||||
// receive original model from parent
|
||||
Buffer model;
|
||||
CreateBufferCall call = [&model](size_t msg_len) -> uint8_t * {
|
||||
model.ResizeData(msg_len);
|
||||
return reinterpret_cast<uint8_t *>(model.MutableData());
|
||||
};
|
||||
auto status = multi_process->ReceiveMsg(call);
|
||||
if (!status.IsSuccess()) {
|
||||
MS_LOG_ERROR << "Receive original model from parent process failed";
|
||||
return FAILED;
|
||||
}
|
||||
Buffer model_result = LoadMindIRInner(model);
|
||||
if (model_result.DataSize() == 0) {
|
||||
MS_LOG_ERROR << "Convert model from MindIR to OM failed";
|
||||
return FAILED;
|
||||
}
|
||||
// send result model to parent
|
||||
status = multi_process->SendMsg(model_result.Data(), model_result.DataSize());
|
||||
if (!status.IsSuccess()) {
|
||||
MS_LOG_ERROR << "Send result model to parent process failed";
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
};
|
||||
auto status = multi_process.MainProcess(parent_process, child_process);
|
||||
if (!status.IsSuccess()) {
|
||||
MS_LOG_ERROR << "Convert MindIR model to OM model failed";
|
||||
} else {
|
||||
MS_LOG_INFO << "Convert MindIR model to OM model success";
|
||||
}
|
||||
return buffer_ret;
|
||||
}
|
||||
|
||||
Buffer ModelConverter::LoadAscendIR(const Buffer &model_data) {
|
||||
if (Py_IsInitialized() == 0) {
|
||||
MS_LOG_INFO << "Call LoadAscendIRInner directly";
|
||||
return LoadAscendIRInner(model_data);
|
||||
}
|
||||
MultiProcess multi_process;
|
||||
Buffer buffer_ret;
|
||||
auto parent_process = [&model_data, &buffer_ret](MultiProcess *multi_process) -> Status {
|
||||
MS_EXCEPTION_IF_NULL(multi_process);
|
||||
// send original model to child
|
||||
auto status = multi_process->SendMsg(model_data.Data(), model_data.DataSize());
|
||||
if (!status.IsSuccess()) {
|
||||
MS_LOG_ERROR << "Send original model to child process failed";
|
||||
return FAILED;
|
||||
}
|
||||
// receive convert model result from child
|
||||
CreateBufferCall call = [&buffer_ret](size_t msg_len) -> uint8_t * {
|
||||
buffer_ret.ResizeData(msg_len);
|
||||
return reinterpret_cast<uint8_t *>(buffer_ret.MutableData());
|
||||
};
|
||||
status = multi_process->ReceiveMsg(call);
|
||||
if (!status.IsSuccess()) {
|
||||
MS_LOG_ERROR << "Receive result model from child process failed";
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
};
|
||||
auto child_process = [this](MultiProcess *multi_process) -> Status {
|
||||
MS_EXCEPTION_IF_NULL(multi_process);
|
||||
// receive original model from parent
|
||||
Buffer model;
|
||||
CreateBufferCall call = [&model](size_t msg_len) -> uint8_t * {
|
||||
model.ResizeData(msg_len);
|
||||
return reinterpret_cast<uint8_t *>(model.MutableData());
|
||||
};
|
||||
auto status = multi_process->ReceiveMsg(call);
|
||||
if (!status.IsSuccess()) {
|
||||
MS_LOG_ERROR << "Receive original model from parent process failed";
|
||||
return FAILED;
|
||||
}
|
||||
Buffer model_result = LoadAscendIRInner(model);
|
||||
if (model_result.DataSize() == 0) {
|
||||
MS_LOG_ERROR << "Convert model from AIR to OM failed";
|
||||
return FAILED;
|
||||
}
|
||||
// send result model to parent
|
||||
status = multi_process->SendMsg(model_result.Data(), model_result.DataSize());
|
||||
if (!status.IsSuccess()) {
|
||||
MS_LOG_ERROR << "Send result model to parent process failed";
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
};
|
||||
auto status = multi_process.MainProcess(parent_process, child_process);
|
||||
if (!status.IsSuccess()) {
|
||||
MS_LOG_ERROR << "Convert AIR model to OM model failed";
|
||||
} else {
|
||||
MS_LOG_INFO << "Convert AIR model to OM model success";
|
||||
}
|
||||
return buffer_ret;
|
||||
}
|
||||
|
||||
Buffer ModelConverter::LoadMindIRInner(const Buffer &model_data) {
|
||||
RegAllOp();
|
||||
auto func_graph = ConvertMindIrToFuncGraph(model_data);
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Convert MindIR to FuncGraph failed.";
|
||||
|
@@ -259,7 +385,8 @@ Buffer ModelConverter::LoadMindIR(const Buffer &model_data) {
|
|||
return om_data;
|
||||
}
|
||||
|
||||
Buffer ModelConverter::LoadAscendIR(const Buffer &model_data) {
|
||||
Buffer ModelConverter::LoadAscendIRInner(const Buffer &model_data) {
|
||||
RegAllOp();
|
||||
ge::Model load_model = ge::Model("loadmodel", "version2");
|
||||
ge::Status ret =
|
||||
ge::Model::Load(reinterpret_cast<const uint8_t *>(model_data.Data()), model_data.DataSize(), load_model);
|
||||
|
|
|
@@ -45,6 +45,9 @@ class ModelConverter {
|
|||
transform::DfGraphPtr ConvertFuncGraphToAIR(const FuncGraphPtr &anf_graph);
|
||||
Buffer BuildAirModel(const transform::DfGraphPtr &graph, const std::map<std::string, std::string> &acl_options);
|
||||
AclModelOptions *options_;
|
||||
|
||||
Buffer LoadMindIRInner(const Buffer &model_data);
|
||||
Buffer LoadAscendIRInner(const Buffer &model_data);
|
||||
};
|
||||
} // namespace mindspore::api
|
||||
|
||||
|
|
|
@@ -18,6 +18,9 @@
|
|||
#include "utils/utils.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
const char *kDeviceTypeAscendCL = "AscendCL";
|
||||
const char *kDeviceTypeAscendMS = "AscendMS";
|
||||
|
||||
Status Model::LoadModel(const Buffer &model_data, ModelType type, const std::map<std::string, std::string> &options) {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
return impl_->LoadModel(model_data, type, options);
|
||||
|
@@ -95,4 +98,9 @@ Model::Model(NetWork network, const std::string &device_type, uint32_t device_id
|
|||
}
|
||||
|
||||
Model::~Model() {}
|
||||
|
||||
bool Model::CheckModelSupport(const std::string &device_type, ModelType model_type) {
|
||||
return ModelFactory::Instance().CheckModelSupport(device_type, model_type);
|
||||
}
|
||||
|
||||
} // namespace mindspore::api
|
||||
|
|
|
@@ -0,0 +1,207 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "cxx_api/model/model_converter_utils/multi_process.h"
|
||||
#include <unistd.h>
|
||||
#include <sys/wait.h>
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <thread>
|
||||
#include "mindspore/core/utils/log_adapter.h"
|
||||
#include "cxx_api/model/model_converter_utils/shared_memory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace api {
|
||||
|
||||
namespace {
|
||||
uint64_t kSharedMemorySize = 100ull << 20; // 100 MB
|
||||
}
|
||||
|
||||
MultiProcess::MultiProcess() = default;
|
||||
|
||||
MultiProcess::~MultiProcess() = default;
|
||||
|
||||
Status MultiProcess::MainProcess(ProcessFuncCall parent_process, ProcessFuncCall child_process) {
|
||||
MS_EXCEPTION_IF_NULL(parent_process);
|
||||
MS_EXCEPTION_IF_NULL(child_process);
|
||||
Status ret;
|
||||
memory_size_ = kSharedMemorySize; // 100 MB
|
||||
SharedMemory shared_memory;
|
||||
ret = shared_memory.Create(memory_size_);
|
||||
if (!ret.IsSuccess()) {
|
||||
MS_LOG_ERROR << "Create shared memory failed";
|
||||
return ret;
|
||||
}
|
||||
pid_t pid = fork();
|
||||
if (pid < 0) {
|
||||
shared_memory.Destroy();
|
||||
MS_LOG_ERROR << "Fork process to convert model failed";
|
||||
return FAILED;
|
||||
}
|
||||
ret = shared_memory.Attach();
|
||||
if (!ret.IsSuccess()) {
|
||||
MS_LOG_ERROR << "Process attach shared memory failed, pid " << pid;
|
||||
return ret;
|
||||
}
|
||||
shmat_addr_ = shared_memory.GetSharedMemoryAddr();
|
||||
if (shmat_addr_ == nullptr) {
|
||||
MS_LOG_ERROR << "Get shared memory failed";
|
||||
return ret;
|
||||
}
|
||||
shmat_data_addr_ = shmat_addr_ + sizeof(MessageFlag) * 2;
|
||||
shmat_data_max_size_ = memory_size_ - (shmat_data_addr_ - shmat_addr_);
|
||||
|
||||
MS_LOG_INFO << "Shm addr " << (uint64_t)shmat_addr_;
|
||||
if (pid == 0) {
|
||||
ChildProcess(child_process);
|
||||
shared_memory.Detach();
|
||||
MS_LOG_INFO << "Model converter: child process exit";
|
||||
exit(0);
|
||||
} else { // parent process
|
||||
ret = ParentProcess(parent_process);
|
||||
shared_memory.Detach();
|
||||
int status;
|
||||
wait(&status);
|
||||
shared_memory.Destroy();
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
Status MultiProcess::ParentProcess(ProcessFuncCall parent_process) {
|
||||
auto parent_msg = reinterpret_cast<MessageFlag *>(shmat_addr_);
|
||||
auto child_msg = reinterpret_cast<MessageFlag *>(shmat_addr_ + sizeof(MessageFlag));
|
||||
send_msg_ = parent_msg;
|
||||
receive_msg_ = child_msg;
|
||||
std::thread heartbeat_thread(MultiProcess::HeartbeatThreadFunc, this);
|
||||
Status ret;
|
||||
try {
|
||||
ret = parent_process(this);
|
||||
if (!ret.IsSuccess()) {
|
||||
MS_LOG_ERROR << "Parent process process failed";
|
||||
}
|
||||
} catch (const std::runtime_error &ex) {
|
||||
MS_LOG_ERROR << "Catch parent process runtime error: " << ex.what();
|
||||
ret = FAILED;
|
||||
}
|
||||
stopped_ = true;
|
||||
send_msg_->stop = true;
|
||||
heartbeat_thread.join();
|
||||
return ret;
|
||||
}
|
||||
|
||||
void MultiProcess::ChildProcess(ProcessFuncCall child_process) {
|
||||
auto parent_msg = reinterpret_cast<MessageFlag *>(shmat_addr_);
|
||||
auto child_msg = reinterpret_cast<MessageFlag *>(shmat_addr_ + sizeof(MessageFlag));
|
||||
send_msg_ = child_msg;
|
||||
receive_msg_ = parent_msg;
|
||||
std::thread heartbeat_thread(MultiProcess::HeartbeatThreadFunc, this);
|
||||
try {
|
||||
auto ret = child_process(this);
|
||||
if (!ret.IsSuccess()) {
|
||||
MS_LOG_ERROR << "Child process process failed";
|
||||
}
|
||||
} catch (const std::runtime_error &ex) {
|
||||
MS_LOG_ERROR << "Catch child process runtime error: " << ex.what();
|
||||
}
|
||||
stopped_ = true;
|
||||
send_msg_->stop = true;
|
||||
heartbeat_thread.join();
|
||||
}
|
||||
|
||||
Status MultiProcess::SendMsg(const void *buffer, uint64_t msg_len) {
|
||||
MS_LOG_INFO << "Start to send message to peer process, msg len " << msg_len;
|
||||
send_msg_->msg_total_len = msg_len;
|
||||
uint64_t cur_offset = 0;
|
||||
while (msg_len > cur_offset) {
|
||||
uint64_t sub_msg_len = std::min(msg_len - cur_offset, shmat_data_max_size_);
|
||||
|
||||
memcpy_s(shmat_data_addr_, shmat_data_max_size_, static_cast<const uint8_t *>(buffer) + cur_offset, sub_msg_len);
|
||||
cur_offset += sub_msg_len;
|
||||
|
||||
send_msg_->msg_len = sub_msg_len;
|
||||
send_msg_->read_finish_flag = false;
|
||||
send_msg_->read_ready_flag = true;
|
||||
MS_LOG_INFO << "Send start " << cur_offset << ", msg len " << sub_msg_len << ", total len " << msg_len;
|
||||
while (!send_msg_->read_finish_flag && !peer_stopped_) {
|
||||
usleep(1000); // 1ms
|
||||
}
|
||||
if (peer_stopped_) {
|
||||
if (!send_msg_->read_finish_flag) {
|
||||
return FAILED;
|
||||
}
|
||||
break;
|
||||
}
|
||||
MS_LOG_INFO << "Send end " << cur_offset << ", msg len " << sub_msg_len << ", total len " << msg_len;
|
||||
}
|
||||
MS_LOG_INFO << "End to send message to peer process, msg len " << msg_len;
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status MultiProcess::ReceiveMsg(CreateBufferCall create_buffer_call) {
|
||||
uint64_t cur_offset = 0;
|
||||
uint8_t *msg_buffer = nullptr;
|
||||
uint64_t msg_len = 0;
|
||||
do {
|
||||
MS_LOG_INFO << "Receive start from " << cur_offset;
|
||||
while (!receive_msg_->read_ready_flag && !peer_stopped_) {
|
||||
usleep(1000); // 1ms
|
||||
}
|
||||
if (peer_stopped_) {
|
||||
return FAILED;
|
||||
}
|
||||
if (msg_buffer == nullptr) {
|
||||
msg_len = receive_msg_->msg_total_len;
|
||||
msg_buffer = create_buffer_call(msg_len);
|
||||
}
|
||||
memcpy_s(msg_buffer + cur_offset, msg_len - cur_offset, shmat_data_addr_, receive_msg_->msg_len);
|
||||
cur_offset += receive_msg_->msg_len;
|
||||
receive_msg_->read_ready_flag = false;
|
||||
receive_msg_->read_finish_flag = true;
|
||||
MS_LOG_INFO << "Receive end, current length " << cur_offset << ", total length " << msg_len << std::endl;
|
||||
} while (msg_len > cur_offset);
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
void MultiProcess::HeartbeatThreadFunc(MultiProcess *multi_process) { multi_process->HeartbeatThreadFuncInner(); }
|
||||
|
||||
void MultiProcess::HeartbeatThreadFuncInner() {
|
||||
uint64_t last_beat_cnt = 0;
|
||||
uint64_t repeat_cnt = 0;
|
||||
while (!stopped_) {
|
||||
if (receive_msg_->stop) {
|
||||
peer_stopped_ = true;
|
||||
MS_LOG_WARNING << "Peer stopped";
|
||||
break;
|
||||
}
|
||||
uint64_t heartbeat_gap = receive_msg_->heartbeat - last_beat_cnt;
|
||||
if (heartbeat_gap > 0 && heartbeat_gap < 1024) {
|
||||
last_beat_cnt = receive_msg_->heartbeat;
|
||||
repeat_cnt = 0;
|
||||
} else {
|
||||
repeat_cnt++;
|
||||
if (repeat_cnt > 30) { // 30*100ms = 3s no reply
|
||||
peer_stopped_ = true;
|
||||
MS_LOG_WARNING << "Peer stopped";
|
||||
break;
|
||||
}
|
||||
}
|
||||
send_msg_->heartbeat += 1;
|
||||
usleep(100000); // sleep 100 ms
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace mindspore
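ModelConverter forks a child process so that graph compilation (which may re-enter the embedded Python and GE environment) runs in a clean address space; parent and child exchange buffers through a System V shared-memory segment guarded by two MessageFlag slots and a heartbeat thread. A minimal round-trip sketch using the same MultiProcess API follows (payloads here are plain strings rather than model buffers, and error handling is trimmed):

#include <string>
#include <vector>
#include "cxx_api/model/model_converter_utils/multi_process.h"
using mindspore::api::FAILED;
using mindspore::api::MultiProcess;
using mindspore::api::Status;
Status RoundTrip() {
  MultiProcess mp;
  std::vector<uint8_t> reply;
  auto parent = [&reply](MultiProcess *mp) -> Status {
    std::string request = "ping";
    if (!mp->SendMsg(request.data(), request.size()).IsSuccess()) return FAILED;
    // CreateBufferCall hands back a destination buffer once the total length is known.
    auto make_buf = [&reply](size_t len) -> uint8_t * { reply.resize(len); return reply.data(); };
    return mp->ReceiveMsg(make_buf);
  };
  auto child = [](MultiProcess *mp) -> Status {
    std::vector<uint8_t> request;
    auto make_buf = [&request](size_t len) -> uint8_t * { request.resize(len); return request.data(); };
    if (!mp->ReceiveMsg(make_buf).IsSuccess()) return FAILED;
    std::string answer = "pong";
    return mp->SendMsg(answer.data(), answer.size());
  };
  return mp.MainProcess(parent, child);  // forks, wires both ends to the shared memory, reaps the child
}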
|
|
@@ -0,0 +1,68 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_CXXAPI_MULTI_PROCESS_H
|
||||
#define MINDSPORE_CCSRC_CXXAPI_MULTI_PROCESS_H
|
||||
#include <iostream>
|
||||
#include <functional>
|
||||
#include "include/api/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace api {
|
||||
|
||||
struct MessageFlag {
|
||||
uint64_t heartbeat = 0;
|
||||
uint64_t stop = false;
|
||||
uint64_t msg_len = 0;
|
||||
uint64_t msg_total_len = 0;
|
||||
uint64_t read_ready_flag = false;
|
||||
uint64_t read_finish_flag = false;
|
||||
};
|
||||
|
||||
class MultiProcess;
|
||||
using ProcessFuncCall = std::function<Status(MultiProcess *multi_process)>;
|
||||
using CreateBufferCall = std::function<uint8_t *(size_t msg_len)>;
|
||||
|
||||
class MultiProcess {
|
||||
public:
|
||||
MultiProcess();
|
||||
~MultiProcess();
|
||||
|
||||
Status MainProcess(ProcessFuncCall parent_process, ProcessFuncCall child_process);
|
||||
Status SendMsg(const void *buffer, uint64_t msg_len);
|
||||
Status ReceiveMsg(CreateBufferCall create_buffer_call);
|
||||
|
||||
private:
|
||||
uint8_t *shmat_addr_ = nullptr;
|
||||
uint8_t *shmat_data_addr_ = nullptr;
|
||||
uint64_t shmat_data_max_size_ = 0;
|
||||
uint64_t memory_size_ = 0;
|
||||
|
||||
bool peer_stopped_ = false;
|
||||
bool stopped_ = false;
|
||||
MessageFlag *send_msg_ = nullptr;
|
||||
MessageFlag *receive_msg_ = nullptr;
|
||||
|
||||
static void HeartbeatThreadFunc(MultiProcess *multi_process);
|
||||
void HeartbeatThreadFuncInner();
|
||||
Status ParentProcess(ProcessFuncCall parent_process);
|
||||
void ChildProcess(ProcessFuncCall child_process);
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_CXXAPI_MULTI_PROCESS_H
|
|
@@ -0,0 +1,69 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cxx_api/model/model_converter_utils/shared_memory.h"
|
||||
#include <sys/shm.h>
|
||||
#include <sys/stat.h>
|
||||
#include <string>
|
||||
#include "mindspore/core/utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace api {
|
||||
|
||||
Status SharedMemory::Create(uint64_t memory_size) {
|
||||
auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP;
|
||||
shm_id_ = shmget(IPC_PRIVATE, memory_size, IPC_CREAT | IPC_EXCL | access_mode);
|
||||
if (shm_id_ == -1) {
|
||||
MS_LOG_ERROR << "Shared memory creation failed. Errno " + std::to_string(errno);
|
||||
return FAILED;
|
||||
}
|
||||
MS_LOG_INFO << "shmget success, shm id " << shm_id_;
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status SharedMemory::Attach() {
|
||||
void *shmat_addr = shmat(shm_id_, nullptr, 0);
|
||||
if (shmat_addr == reinterpret_cast<void *>(-1)) {
|
||||
MS_LOG_ERROR << "Shared memory attach failed. Errno " + std::to_string(errno);
|
||||
return FAILED;
|
||||
}
|
||||
shmat_addr_ = reinterpret_cast<uint8_t *>(shmat_addr);
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
void SharedMemory::Detach() {
|
||||
if (shmat_addr_) {
|
||||
auto err = shmdt(shmat_addr_);
|
||||
if (err == -1) {
|
||||
MS_LOG_ERROR << "Shared memory detach failed. Errno " + std::to_string(errno);
|
||||
return;
|
||||
}
|
||||
}
|
||||
shmat_addr_ = nullptr;
|
||||
}
|
||||
|
||||
void SharedMemory::Destroy() {
|
||||
// Remove the shared memory and never mind about the return code.
|
||||
auto err = shmctl(shm_id_, IPC_RMID, nullptr);
|
||||
if (err == -1) {
|
||||
std::string errMsg = "Unable to remove shared memory with id " + std::to_string(shm_id_);
|
||||
errMsg += ". Errno :" + std::to_string(errno);
|
||||
errMsg += "\nPlease remove it manually using the ipcrm -m command";
|
||||
MS_LOG_ERROR << errMsg;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace mindspore
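The SharedMemory helper wraps the System V calls: in MultiProcess the parent creates the segment before fork, both sides attach and later detach, and the parent destroys it after wait(). Standalone, the lifecycle looks like this (a sketch; only the Create/Attach return codes are checked):

#include <cstring>
#include "cxx_api/model/model_converter_utils/shared_memory.h"
void SharedMemoryDemo() {
  mindspore::api::SharedMemory shm;
  if (!shm.Create(1 << 20).IsSuccess()) return;          // shmget(IPC_PRIVATE, 1 MB, ...)
  if (!shm.Attach().IsSuccess()) { shm.Destroy(); return; }
  std::memcpy(shm.GetSharedMemoryAddr(), "hello", 6);    // the segment is plain bytes
  shm.Detach();                                          // shmdt
  shm.Destroy();                                         // shmctl(IPC_RMID)
}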
|
|
@@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_CXXAPI_SHARED_MEMORY_H
|
||||
#define MINDSPORE_CCSRC_CXXAPI_SHARED_MEMORY_H
|
||||
#include <iostream>
|
||||
#include "include/api/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace api {
|
||||
|
||||
class SharedMemory {
|
||||
public:
|
||||
Status Create(uint64_t memory_size);
|
||||
Status Attach();
|
||||
void Detach();
|
||||
void Destroy();
|
||||
uint8_t *GetSharedMemoryAddr() { return shmat_addr_; }
|
||||
|
||||
private:
|
||||
int shm_id_ = -1;
|
||||
uint8_t *shmat_addr_ = nullptr;
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_CXXAPI_SHARED_MEMORY_H
|
|
@@ -70,6 +70,12 @@ class ModelFactory {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
bool CheckModelSupport(const std::string &device_type, ModelType /*model_type*/) {
|
||||
return std::any_of(
|
||||
model_creators_.begin(), model_creators_.end(),
|
||||
[&device_type](const std::pair<std::string, ModelCreator> &item) { return item.first == device_type; });
|
||||
}
|
||||
|
||||
private:
|
||||
ModelFactory() = default;
|
||||
~ModelFactory() = default;
|
||||
|
@@ -86,7 +92,7 @@ class ModelRegistrar {
|
|||
|
||||
#define API_REG_MODEL(DEVICE_NAME, MODEL_CLASS) \
|
||||
static const ModelRegistrar g_api_model_registrar__##DEVICE_NAME##_##_reg( \
|
||||
#DEVICE_NAME, [](uint32_t device_id) { return std::make_shared<MODEL_CLASS>(device_id); });
|
||||
kDeviceType##DEVICE_NAME, [](uint32_t device_id) { return std::make_shared<MODEL_CLASS>(device_id); });
|
||||
|
||||
} // namespace mindspore::api
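A note on the registration path (explanatory, not part of the diff): API_REG_MODEL now keys the creator by the exported constant (kDeviceType##DEVICE_NAME) rather than the stringified macro argument, so the strings callers pass in — kDeviceTypeAscendMS ("AscendMS") and kDeviceTypeAscendCL ("AscendCL") — match the keys that CheckModelSupport scans in model_creators_. For example, the new ms_model.h registers and is then discoverable as:

API_REG_MODEL(AscendMS, MsModel);  // stores a creator under kDeviceTypeAscendMS
// later, from Model::CheckModelSupport; the model_type argument is currently unused
bool supported = ModelFactory::Instance().CheckModelSupport(kDeviceTypeAscendMS, model_type);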
|
||||
|
||||
|
|
|
@@ -0,0 +1,418 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "cxx_api/model/ms/ms_model.h"
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
|
||||
#include "utils/load_onnx/anf_converter.h"
|
||||
#include "backend/session/session_basic.h"
|
||||
#include "backend/session/session_factory.h"
|
||||
#include "backend/session/executor_manager.h"
|
||||
#include "base/base_ref_utils.h"
|
||||
#include "backend/kernel_compiler/oplib/oplib.h"
|
||||
#include "utils/context/context_extends.h"
|
||||
#include "runtime/device/kernel_runtime_manager.h"
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/embed.h"
|
||||
|
||||
#ifdef ENABLE_D
|
||||
#include "utils/ms_context.h"
|
||||
#endif
|
||||
|
||||
using std::string;
|
||||
using std::vector;
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace mindspore {
|
||||
namespace api {
|
||||
|
||||
MsModel::MsModel(uint32_t device_id) : device_id_(device_id) {}
|
||||
MsModel::~MsModel() = default;
|
||||
|
||||
TypeId TransInferDataType2TypeId(DataType data_type) {
|
||||
const std::map<api::DataType, TypeId> type2id_map{
|
||||
{api::kMsUnknown, TypeId::kNumberTypeBegin}, {api::kMsBool, TypeId::kNumberTypeBool},
|
||||
{api::kMsInt8, TypeId::kNumberTypeInt8}, {api::kMsUint8, TypeId::kNumberTypeUInt8},
|
||||
{api::kMsInt16, TypeId::kNumberTypeInt16}, {api::kMsUint16, TypeId::kNumberTypeUInt16},
|
||||
{api::kMsInt32, TypeId::kNumberTypeInt32}, {api::kMsUint32, TypeId::kNumberTypeUInt32},
|
||||
{api::kMsInt64, TypeId::kNumberTypeInt64}, {api::kMsUint64, TypeId::kNumberTypeUInt64},
|
||||
{api::kMsFloat16, TypeId::kNumberTypeFloat16}, {api::kMsFloat32, TypeId::kNumberTypeFloat32},
|
||||
{api::kMsFloat64, TypeId::kNumberTypeFloat64},
|
||||
};
|
||||
auto it = type2id_map.find(data_type);
|
||||
if (it == type2id_map.end()) {
|
||||
MS_LOG_WARNING << "Unsupported MSI data type " << data_type;
|
||||
return TypeId::kNumberTypeBegin;
|
||||
} else {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
DataType TransTypeId2InferDataType(TypeId type_id) {
|
||||
const std::map<TypeId, api::DataType> id2type_map{
|
||||
{TypeId::kNumberTypeBegin, api::kMsUnknown}, {TypeId::kNumberTypeBool, api::kMsBool},
|
||||
{TypeId::kNumberTypeFloat64, api::kMsFloat64}, {TypeId::kNumberTypeInt8, api::kMsInt8},
|
||||
{TypeId::kNumberTypeUInt8, api::kMsUint8}, {TypeId::kNumberTypeInt16, api::kMsInt16},
|
||||
{TypeId::kNumberTypeUInt16, api::kMsUint16}, {TypeId::kNumberTypeInt32, api::kMsInt32},
|
||||
{TypeId::kNumberTypeUInt32, api::kMsUint32}, {TypeId::kNumberTypeInt64, api::kMsInt64},
|
||||
{TypeId::kNumberTypeUInt64, api::kMsUint64}, {TypeId::kNumberTypeFloat16, api::kMsFloat16},
|
||||
{TypeId::kNumberTypeFloat32, api::kMsFloat32},
|
||||
};
|
||||
auto it = id2type_map.find(type_id);
|
||||
if (it == id2type_map.end()) {
|
||||
MS_LOG_WARNING << "Unsupported data id " << type_id;
|
||||
return api::kMsUnknown;
|
||||
} else {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
Buffer MsModel::ReadFile(const std::string &file) {
|
||||
if (file.empty()) {
|
||||
MS_LOG(ERROR) << "file is nullptr";
|
||||
return Buffer();
|
||||
}
|
||||
std::ifstream ifs(file);
|
||||
if (!ifs.good()) {
|
||||
MS_LOG(ERROR) << "file: " << file << " is not exist";
|
||||
return Buffer();
|
||||
}
|
||||
|
||||
if (!ifs.is_open()) {
|
||||
MS_LOG(ERROR) << "file: " << file << "open failed";
|
||||
return Buffer();
|
||||
}
|
||||
|
||||
ifs.seekg(0, std::ios::end);
|
||||
size_t size = ifs.tellg();
|
||||
Buffer buffer;
|
||||
buffer.ResizeData(size);
|
||||
ifs.seekg(0, std::ios::beg);
|
||||
ifs.read(static_cast<char *>(buffer.MutableData()), size);
|
||||
ifs.close();
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
||||
Status MsModel::LoadModel(const Buffer &model_data, ModelType type, const std::map<std::string, std::string> &options) {
|
||||
auto status = InitEnv({});
|
||||
if (status != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Init env failed";
|
||||
return FAILED;
|
||||
}
|
||||
std::shared_ptr<FuncGraph> anf_graph;
|
||||
try {
|
||||
anf_graph =
|
||||
lite::AnfConverter::RunAnfConverter(static_cast<const char *>(model_data.Data()), model_data.DataSize());
|
||||
} catch (std::exception &e) {
|
||||
MS_LOG(ERROR) << "Inference LoadModel failed";
|
||||
return FAILED;
|
||||
}
|
||||
Status ret = CompileGraph(anf_graph);
|
||||
if (ret != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Compile graph model failed";
|
||||
return FAILED;
|
||||
}
|
||||
session_impl_->GetModelInputsInfo(graph_id_, &inputs_, &input_names_);
|
||||
session_impl_->GetModelOutputsInfo(graph_id_, &outputs_, &output_names_);
|
||||
if (inputs_.empty() || inputs_.size() != input_names_.size()) {
|
||||
MS_LOG_ERROR << "Get model inputs info failed";
|
||||
return FAILED;
|
||||
}
|
||||
if (outputs_.empty() || outputs_.size() != output_names_.size()) {
|
||||
MS_LOG_ERROR << "Get model outputs info failed";
|
||||
return FAILED;
|
||||
}
|
||||
MS_LOG(INFO) << "Load model success";
|
||||
|
||||
#ifdef ENABLE_D
|
||||
// set d context
|
||||
rtError_t rt_ret = rtCtxGetCurrent(&context_);
|
||||
if (rt_ret != RT_ERROR_NONE || context_ == nullptr) {
|
||||
MS_LOG(ERROR) << "the ascend device context is null";
|
||||
return FAILED;
|
||||
}
|
||||
#endif
|
||||
load_flag_ = true;
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status MsModel::LoadModel(const std::string &file_name, ModelType type,
|
||||
const std::map<std::string, std::string> &options) {
|
||||
auto graphBuf = ReadFile(file_name);
|
||||
if (graphBuf.DataSize() == 0) {
|
||||
MS_LOG(ERROR) << "Read model file failed, file name is " << file_name.c_str();
|
||||
return FAILED;
|
||||
}
|
||||
auto status = LoadModel(graphBuf, type, options);
|
||||
if (status != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status MsModel::UnloadModel() {
|
||||
if (!load_flag_) {
|
||||
MS_LOG_ERROR << "Model has not been loaded";
|
||||
return FAILED;
|
||||
}
|
||||
FinalizeEnv();
|
||||
load_flag_ = false;
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status MsModel::Train(const DataSet &, std::map<std::string, Buffer> *) {
|
||||
MS_LOG(ERROR) << "Unsupported feature.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
Status MsModel::Eval(const DataSet &, std::map<std::string, Buffer> *) {
|
||||
MS_LOG(ERROR) << "Unsupported feature.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
Status MsModel::Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs) {
|
||||
MS_EXCEPTION_IF_NULL(outputs);
|
||||
if (!load_flag_) {
|
||||
MS_LOG(ERROR) << "No model is loaded, predict failed.";
|
||||
return FAILED;
|
||||
}
|
||||
if (inputs.size() != inputs_.size()) {
|
||||
MS_LOG(ERROR) << "inputs count not match, required count " << inputs_.size() << ", given count " << inputs.size();
|
||||
return INVALID_INPUTS;
|
||||
}
|
||||
std::vector<Buffer> request;
|
||||
std::vector<Buffer> reply;
|
||||
for (size_t i = 0; i < inputs_.size(); ++i) {
|
||||
const auto &input_name = input_names_[i];
|
||||
auto iter = inputs.find(input_name);
|
||||
if (iter == inputs.end()) {
|
||||
MS_LOG(ERROR) << "Model missing input " << input_name;
|
||||
return INVALID_INPUTS;
|
||||
}
|
||||
|
||||
if (iter->second.DataSize() != inputs_[i]->Size()) {
|
||||
MS_LOG(ERROR) << "input " << i << " data size not match, required size " << inputs_[i]->Size() << ", given count "
|
||||
<< iter->second.DataSize();
|
||||
return INVALID_INPUTS;
|
||||
}
|
||||
request.push_back(iter->second);
|
||||
}
|
||||
if (ExecuteModel(request, &reply) != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Execute Model Failed";
|
||||
return FAILED;
|
||||
}
|
||||
if (outputs_.size() != reply.size()) {
|
||||
MS_LOG(ERROR) << "Predict output size " << reply.size() << " not match output size got from model info "
|
||||
<< outputs_.size();
|
||||
return FAILED;
|
||||
}
|
||||
outputs->clear();
|
||||
for (size_t i = 0; i < reply.size(); i++) {
|
||||
outputs->emplace(output_names_[i], reply[i]);
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status MsModel::ExecuteModel(const std::vector<Buffer> &request, std::vector<Buffer> *reply) {
|
||||
MS_EXCEPTION_IF_NULL(reply);
|
||||
#ifdef ENABLE_D
|
||||
if (context_ == nullptr) {
|
||||
MS_LOG(ERROR) << "rtCtx is nullptr";
|
||||
return FAILED;
|
||||
}
|
||||
rtError_t rt_ret = rtCtxSetCurrent(context_);
|
||||
if (rt_ret != RT_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "set Ascend rtCtx failed";
|
||||
return FAILED;
|
||||
}
|
||||
#endif
|
||||
vector<tensor::TensorPtr> inputs;
|
||||
for (size_t i = 0; i < request.size(); i++) {
|
||||
auto &item = request[i];
|
||||
auto input = inputs_[i];
|
||||
if (input->Size() != item.DataSize()) {
|
||||
MS_LOG(ERROR) << "Predict input " << i << " data size " << item.DataSize() << " not match model input data size "
|
||||
<< input->Size();
|
||||
return FAILED;
|
||||
}
|
||||
auto ret = memcpy_s(input->data_c(), input->Size(), item.Data(), item.DataSize());
|
||||
if (ret != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Tensor copy failed";
|
||||
return FAILED;
|
||||
}
|
||||
inputs.push_back(input);
|
||||
}
|
||||
vector<tensor::TensorPtr> outputs = RunGraph(inputs);
|
||||
if (outputs.empty()) {
|
||||
MS_LOG(ERROR) << "Execute Model Failed";
|
||||
return FAILED;
|
||||
}
|
||||
reply->clear();
|
||||
std::transform(outputs.begin(), outputs.end(), std::back_inserter(*reply),
|
||||
[](const tensor::TensorPtr &tensor) { return Buffer(tensor->data_c(), tensor->Size()); });
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status MsModel::FinalizeEnv() {
|
||||
MS_LOG_INFO << "Start finalize env";
|
||||
py::gil_scoped_acquire acquire;
|
||||
session::ExecutorManager::Instance().Clear();
|
||||
device::KernelRuntimeManager::Instance().ClearRuntimeResource();
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
if (ms_context == nullptr) {
|
||||
MS_LOG(ERROR) << "Get Context failed!";
|
||||
return FAILED;
|
||||
}
|
||||
if (!context::CloseTsd(ms_context)) {
|
||||
MS_LOG(ERROR) << "Inference CloseTsd failed!";
|
||||
return FAILED;
|
||||
}
|
||||
MS_LOG_INFO << "End finalize env";
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
std::shared_ptr<FuncGraph> MsModel::LoadModel(const char *model_buf, size_t size, const std::string &device) {
|
||||
MS_EXCEPTION_IF_NULL(model_buf);
|
||||
try {
|
||||
auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size);
|
||||
return anf_graph;
|
||||
} catch (std::exception &e) {
|
||||
MS_LOG(ERROR) << "Inference LoadModel failed: " << e.what();
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void MsModel::RegAllOp() {
|
||||
static std::mutex init_mutex;
|
||||
static bool Initialized = false;
|
||||
|
||||
std::lock_guard<std::mutex> lock(init_mutex);
|
||||
if (Initialized) {
|
||||
return;
|
||||
}
|
||||
Initialized = true;
|
||||
auto ms_context_instance = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context_instance);
|
||||
ms_context_instance->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
|
||||
try {
|
||||
std::shared_ptr<py::scoped_interpreter> guard;
|
||||
if (Py_IsInitialized() == 0) {
|
||||
guard = std::make_shared<py::scoped_interpreter>();
|
||||
}
|
||||
py::module c_expression = py::module::import("mindspore._c_expression");
|
||||
size_t ops_info_long = c_expression.attr("OpInfoLoaderPy")().attr("get_all_ops_info")().cast<size_t>();
|
||||
auto all_ops_info = reinterpret_cast<std::vector<kernel::OpInfo *> *>(ops_info_long);
|
||||
for (auto op_info : *all_ops_info) {
|
||||
kernel::OpLib::RegOpInfo(std::shared_ptr<kernel::OpInfo>(op_info));
|
||||
}
|
||||
all_ops_info->clear();
|
||||
delete all_ops_info;
|
||||
} catch (const std::runtime_error &ex) {
|
||||
MS_LOG_EXCEPTION << ex.what();
|
||||
}
|
||||
}
|
||||
|
||||
Status MsModel::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr) {
|
||||
MS_ASSERT(session_impl_ != nullptr);
|
||||
try {
|
||||
graph_id_ = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
|
||||
py::gil_scoped_release gil_release;
|
||||
return SUCCESS;
|
||||
} catch (std::exception &e) {
|
||||
MS_LOG(ERROR) << "Inference CompileGraph failed: " << e.what();
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<tensor::TensorPtr> MsModel::RunGraph(const std::vector<tensor::TensorPtr> &inputs) {
|
||||
try {
|
||||
VectorRef outputs;
|
||||
session_impl_->RunGraph(graph_id_, inputs, &outputs);
|
||||
return TransformVectorRefToMultiTensor(outputs);
|
||||
} catch (std::exception &e) {
|
||||
MS_LOG(ERROR) << "Inference Rungraph failed: " << e.what();
|
||||
return std::vector<tensor::TensorPtr>();
|
||||
}
|
||||
}
|
||||
|
||||
Status MsModel::InitEnv(const std::unordered_map<std::string, std::string> &other_options) {
|
||||
RegAllOp();
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
if (ms_context == nullptr) {
|
||||
MS_LOG(ERROR) << "Get Context failed!";
|
||||
return FAILED;
|
||||
}
|
||||
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
|
||||
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_);
|
||||
ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kAscendDevice);
|
||||
if (!context::OpenTsd(ms_context)) {
|
||||
MS_LOG(ERROR) << "Session init OpenTsd failed!";
|
||||
return FAILED;
|
||||
}
|
||||
session_impl_ = session::SessionFactory::Get().Create(kDavinciInferenceDevice);
|
||||
if (session_impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << kDavinciInferenceDevice
|
||||
<< " is available.";
|
||||
return FAILED;
|
||||
}
|
||||
session_impl_->Init(device_id_);
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status MsModel::CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const {
|
||||
MS_ASSERT(session_impl_ != nullptr);
|
||||
std::string error_msg;
|
||||
if (!session_impl_->CheckModelInputs(graph_id, inputs, &error_msg)) {
|
||||
return Status(INVALID_INPUTS, error_msg);
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status MsModel::GetInputsInfo(std::vector<Tensor> *tensor_list) const {
|
||||
MS_EXCEPTION_IF_NULL(tensor_list);
|
||||
tensor_list->clear();
|
||||
for (size_t i = 0; i < inputs_.size(); i++) {
|
||||
auto &tensor = inputs_[i];
|
||||
Tensor infer_tensor;
|
||||
infer_tensor.SetName(input_names_[i]);
|
||||
infer_tensor.SetDataType(TransTypeId2InferDataType(tensor->data_type()));
|
||||
infer_tensor.SetShape(tensor->shape());
|
||||
tensor_list->push_back(infer_tensor);
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status MsModel::GetOutputsInfo(std::vector<Tensor> *tensor_list) const {
|
||||
MS_EXCEPTION_IF_NULL(tensor_list);
|
||||
tensor_list->clear();
|
||||
for (size_t i = 0; i < outputs_.size(); i++) {
|
||||
auto &tensor = outputs_[i];
|
||||
Tensor infer_tensor;
|
||||
infer_tensor.SetName(output_names_[i]);
|
||||
infer_tensor.SetDataType(TransTypeId2InferDataType(tensor->data_type()));
|
||||
infer_tensor.SetShape(tensor->shape());
|
||||
tensor_list->push_back(infer_tensor);
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace mindspore
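MsModel::Predict matches the caller's map against the input names captured by GetModelInputsInfo at load time and requires the byte size of each Buffer to equal the corresponding graph input. A small sketch of building that map from GetInputsInfo; the Tensor::Name() getter is assumed here, while the copying Buffer(data, size) constructor is the one already used by ExecuteModel above:

#include <map>
#include <string>
#include <utility>
#include <vector>
#include "cxx_api/model/ms/ms_model.h"
using mindspore::api::Buffer;
using mindspore::api::Tensor;
std::map<std::string, Buffer> MakeInputs(const mindspore::api::MsModel &model,
                                         const std::vector<std::pair<const void *, size_t>> &host_data) {
  std::vector<Tensor> info;
  model.GetInputsInfo(&info);  // names, dtypes and shapes of the loaded graph's inputs
  std::map<std::string, Buffer> inputs;
  for (size_t i = 0; i < info.size() && i < host_data.size(); ++i) {
    // Key must be the graph input name; the buffer size must equal that input's data size.
    inputs.emplace(info[i].Name(), Buffer(host_data[i].first, host_data[i].second));
  }
  return inputs;
}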
|
|
@@ -0,0 +1,85 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_SESSION_SESSION_H
|
||||
#define MINDSPORE_CCSRC_SESSION_SESSION_H
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
|
||||
#include "backend/session/session_basic.h"
|
||||
#include "ir/anf.h"
|
||||
#include "include/api/status.h"
|
||||
#include "cxx_api/model/model_impl.h"
|
||||
|
||||
#ifdef ENABLE_D
|
||||
#include "runtime/context.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace api {
|
||||
class MsModel : public ModelImpl {
|
||||
public:
|
||||
explicit MsModel(uint32_t device_id);
|
||||
~MsModel();
|
||||
|
||||
Status LoadModel(const Buffer &model_data, ModelType type,
|
||||
const std::map<std::string, std::string> &options) override;
|
||||
Status LoadModel(const std::string &file_name, ModelType type,
|
||||
const std::map<std::string, std::string> &options) override;
|
||||
Status UnloadModel() override;
|
||||
|
||||
Status Train(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override;
|
||||
Status Eval(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override;
|
||||
Status Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs) override;
|
||||
|
||||
Status GetInputsInfo(std::vector<Tensor> *tensor_list) const override;
|
||||
Status GetOutputsInfo(std::vector<Tensor> *tensor_list) const override;
|
||||
|
||||
Status InitEnv(const std::unordered_map<std::string, std::string> &other_options);
|
||||
Status FinalizeEnv();
|
||||
|
||||
private:
|
||||
std::shared_ptr<session::SessionBasic> session_impl_ = nullptr;
|
||||
uint32_t graph_id_;
|
||||
std::string device_type_;
|
||||
int32_t device_id_ = 0;
|
||||
#ifdef ENABLE_D
|
||||
rtContext_t context_ = nullptr;
|
||||
#endif
|
||||
std::vector<tensor::TensorPtr> inputs_;
|
||||
std::vector<tensor::TensorPtr> outputs_;
|
||||
std::vector<std::string> input_names_;
|
||||
std::vector<std::string> output_names_;
|
||||
bool load_flag_ = false;
|
||||
|
||||
std::shared_ptr<FuncGraph> LoadModel(const char *model_buf, size_t size, const std::string &device);
|
||||
Buffer ReadFile(const std::string &file);
|
||||
static void RegAllOp();
|
||||
Status CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr);
|
||||
Status CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const;
|
||||
std::vector<tensor::TensorPtr> RunGraph(const std::vector<tensor::TensorPtr> &inputs);
|
||||
Status ExecuteModel(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs);
|
||||
};
|
||||
|
||||
API_REG_MODEL(AscendMS, MsModel);
|
||||
|
||||
} // namespace api
|
||||
} // namespace mindspore
|
||||
#endif  // MINDSPORE_CCSRC_SESSION_SESSION_H
|