support mindspore::api::Model 910 inference, support 310 model conversion in Python env

This commit is contained in:
xuyongfei 2020-11-23 17:07:53 +08:00
parent 452cb0dd4e
commit a3b9218919
18 changed files with 1123 additions and 40 deletions

View File

@@ -50,9 +50,15 @@ class MS_API Model {
Status GetInputsInfo(std::vector<Tensor> *tensor_list) const;
Status GetOutputsInfo(std::vector<Tensor> *tensor_list) const;
static bool CheckModelSupport(const std::string& device_type, ModelType model_type);
private:
std::shared_ptr<ModelImpl> impl_;
};
extern MS_API const char* kDeviceTypeAscendCL;
extern MS_API const char* kDeviceTypeAscendMS;
} // namespace api
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_MODEL_H
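For orientation, a minimal sketch of how the new entry points could be exercised; only CheckModelSupport and kDeviceTypeAscendMS come from this header, while the ModelType::kMindIR enumerator and the surrounding flow are assumptions.

#include "include/api/model.h"

// Hedged sketch: probe whether the in-process 910 backend (MsModel) is registered
// before trying to use it. ModelType::kMindIR is assumed to exist in include/api/types.h.
bool Has910InferenceBackend() {
  using namespace mindspore::api;
  return Model::CheckModelSupport(kDeviceTypeAscendMS, ModelType::kMindIR);
}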

View File

@@ -213,30 +213,5 @@ std::string AscendInferenceSession::InputsInfo(const std::vector<ParameterPtr> &
return graph + " " + actual;
}
void AscendInferenceSession::GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs) const {
MS_LOG(INFO) << "Start get model inputs, graph id : " << graph_id;
auto kernel_graph = GetGraph(graph_id);
MS_EXCEPTION_IF_NULL(kernel_graph);
auto kernel_graph_inputs = kernel_graph->inputs();
vector<ParameterPtr> paras;
// find parameters of graph inputs
for (size_t i = 0; i < kernel_graph_inputs.size(); ++i) {
if (!kernel_graph_inputs[i]->isa<Parameter>()) {
MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter.";
continue;
}
auto parameter = kernel_graph_inputs[i]->cast<ParameterPtr>();
if (!AnfAlgo::IsParameterWeight(parameter)) {
vector<int64_t> input_shape;
auto parameter_shape = AnfAlgo::GetOutputDeviceShape(parameter, 0);
(void)std::transform(parameter_shape.begin(), parameter_shape.end(), std::back_inserter(input_shape),
[](const size_t dim) { return SizeToLong(dim); });
auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(parameter);
auto data_type = kernel_build_info->GetOutputDeviceType(0);
auto ms_tensor = std::make_shared<tensor::Tensor>(data_type, input_shape);
inputs->push_back(ms_tensor);
}
}
}
} // namespace session
} // namespace mindspore

View File

@@ -44,7 +44,6 @@ class AscendInferenceSession : public AscendSession {
template <typename T>
std::string PrintInputShape(std::vector<T> shape) const;
std::string InputsInfo(const std::vector<ParameterPtr> &paras, const std::vector<tensor::TensorPtr> &inputs) const;
void GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs) const override;
protected:
GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) override;

View File

@@ -370,7 +370,8 @@ Status MSInferSession::CheckModelInputs(uint32_t graph_id, const std::vector<ten
Status MSInferSession::GetModelInputsInfo(uint32_t model_id, std::vector<inference::InferTensor> *tensor_list) const {
vector<tensor::TensorPtr> inputs;
session_impl_->GetModelInputsInfo(model_id, &inputs);
vector<std::string> input_names;
session_impl_->GetModelInputsInfo(model_id, &inputs, &input_names);
if (inputs.size() == 0) {
MS_LOG(ERROR) << "The model inputs is NULL";
return FAILED;

View File

@@ -34,6 +34,8 @@
#include "ir/func_graph_cloner.h"
#include "utils/utils.h"
#include "debug/anf_ir_dump.h"
#include "mindspore/core/base/base_ref_utils.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/worker.h"
#include "ps/util.h"
@@ -1089,6 +1091,61 @@ void SessionBasic::CreateOutputTensors(const GraphId &graph_id, const std::vecto
}
}
void SessionBasic::GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs,
std::vector<std::string> *inputs_name) const {
MS_LOG(INFO) << "Start get model inputs, graph id : " << graph_id;
auto kernel_graph = GetGraph(graph_id);
MS_EXCEPTION_IF_NULL(kernel_graph);
MS_EXCEPTION_IF_NULL(inputs);
MS_EXCEPTION_IF_NULL(inputs_name);
auto kernel_graph_inputs = kernel_graph->inputs();
vector<ParameterPtr> paras;
// find parameters of graph inputs
for (size_t i = 0; i < kernel_graph_inputs.size(); ++i) {
if (!kernel_graph_inputs[i]->isa<Parameter>()) {
MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter.";
continue;
}
auto parameter = kernel_graph_inputs[i]->cast<ParameterPtr>();
if (!AnfAlgo::IsParameterWeight(parameter)) {
vector<int64_t> input_shape;
auto parameter_shape = AnfAlgo::GetOutputDeviceShape(parameter, 0);
(void)std::transform(parameter_shape.begin(), parameter_shape.end(), std::back_inserter(input_shape),
[](const size_t dim) { return SizeToLong(dim); });
auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(parameter);
auto data_type = kernel_build_info->GetOutputDeviceType(0);
auto ms_tensor = std::make_shared<tensor::Tensor>(data_type, input_shape);
inputs->push_back(ms_tensor);
inputs_name->push_back(parameter->name());
}
}
}
void SessionBasic::GetModelOutputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *outputs,
std::vector<std::string> *output_names) const {
std::vector<tensor::TensorPtr> inputs;
std::vector<std::string> input_names;
GetModelInputsInfo(graph_id, &inputs, &input_names);
auto kernel_graph = GetGraph(graph_id);
MS_EXCEPTION_IF_NULL(kernel_graph);
MS_EXCEPTION_IF_NULL(outputs);
MS_EXCEPTION_IF_NULL(output_names);
VectorRef vector_outputs;
std::map<tensor::TensorPtr, session::KernelWithIndex> tensor_to_node;
auto anf_outputs = kernel_graph->outputs();
for (auto &item : anf_outputs) {
MS_EXCEPTION_IF_NULL(item);
MS_LOG(INFO) << "Create node output[" << item->DebugString() << "]";
vector_outputs.emplace_back(CreateNodeOutputTensors(item, kernel_graph, inputs, &tensor_to_node));
}
*outputs = TransformVectorRefToMultiTensor(vector_outputs);
for (size_t i = 0; i < outputs->size(); i++) {
output_names->push_back("output" + std::to_string(i));
}
}
void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) {
MS_EXCEPTION_IF_NULL(callback);
summary_callback_ = callback;
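As a rough illustration of how a caller might consume the two new SessionBasic queries (the session object, graph_id, and surrounding setup are assumed and not part of this diff):

// Hedged sketch: after CompileGraph has produced graph_id, fetch the model's
// non-weight inputs and its outputs together with their names, in matching order.
std::vector<tensor::TensorPtr> inputs, outputs;
std::vector<std::string> input_names, output_names;
session->GetModelInputsInfo(graph_id, &inputs, &input_names);    // weight Parameters are skipped
session->GetModelOutputsInfo(graph_id, &outputs, &output_names); // names are "output0", "output1", ...
// inputs[i] carries the device shape and dtype of the i-th non-weight Parameter,
// and input_names[i] is that Parameter's name.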

View File

@@ -102,7 +102,10 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
std::string *error_msg) const {
return true;
}
virtual void GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs) const {}
void GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs,
std::vector<std::string> *inputs_name) const;
void GetModelOutputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *outputs,
std::vector<std::string> *outputs_name) const;
std::vector<tensor::TensorPtr> GetInputNeedLockTensors(const GraphId &graph_id,
const std::vector<tensor::TensorPtr> &inputs);
// Get graph by graph id, if not exist return null ptr

View File

@@ -6,22 +6,25 @@ set(LOAD_ONNX_SRC
file(GLOB_RECURSE API_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR} "ops/*.cc")
if (ENABLE_ACL)
file(GLOB_RECURSE API_ACL_SRC ${CMAKE_CURRENT_SOURCE_DIR} "model/acl/*.cc")
file(GLOB_RECURSE API_ACL_SRC ${CMAKE_CURRENT_SOURCE_DIR} "model/acl/*.cc" "model/model_converter_utils/*.cc")
elseif (ENABLE_D)
file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "model/ms/*.cc")
endif ()
set(MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc
${CMAKE_CURRENT_SOURCE_DIR}/cell.cc
${CMAKE_CURRENT_SOURCE_DIR}/serialization.cc
${CMAKE_CURRENT_SOURCE_DIR}/model/model.cc
${API_ACL_SRC}
${API_OPS_SRC}
${LOAD_ONNX_SRC})
${CMAKE_CURRENT_SOURCE_DIR}/cell.cc
${CMAKE_CURRENT_SOURCE_DIR}/serialization.cc
${CMAKE_CURRENT_SOURCE_DIR}/model/model.cc
${API_MS_INFER_SRC}
${API_ACL_SRC}
${API_OPS_SRC}
${LOAD_ONNX_SRC})
add_library(mindspore_shared_lib SHARED ${MSLIB_SRC})
set_target_properties(mindspore_shared_lib PROPERTIES OUTPUT_NAME mindspore PUBLIC_HEADER "${API_INCLUDE}")
target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
-Wl,--whole-archive mindspore -Wl,--no-whole-archive mindspore_gvar mindspore::protobuf)
-Wl,--whole-archive mindspore -Wl,--no-whole-archive proto_input mindspore_gvar mindspore::protobuf)
if (ENABLE_CPU)
target_link_libraries(mindspore_shared_lib PRIVATE mindspore::dnnl mindspore::mkldnn)
@@ -58,5 +61,13 @@ if (ENABLE_ACL)
find_library(acl_runtime libruntime.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
find_library(ge_compiler libge_compiler.so ${ATC_DIR}/lib64 ${ATLAS_ATC_DIR}/lib64)
target_link_libraries(mindspore_shared_lib PRIVATE ${acl} ${acl_retr} ${acl_cblas} ${acl_dvpp} ${acl_runtime}
${ge_compiler} mindspore::jpeg_turbo)
${ge_compiler} mindspore::jpeg_turbo)
endif ()
# Before building inference
if (ENABLE_D)
find_library(adump_server libadump_server.a ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
target_link_libraries(mindspore_shared_lib PRIVATE ${adump_server})
endif ()

View File

@@ -91,7 +91,6 @@ Status AclModel::InitEnv() {
MS_LOG(ERROR) << "DVPP init resource failed";
return FAILED;
}
ModelConverter::RegAllOp();
MS_LOG(INFO) << "Init acl success, device id " << device_id_;
init_flag_ = true;

View File

@@ -24,6 +24,7 @@
#include "backend/kernel_compiler/oplib/oplib.h"
#include "graph/model.h"
#include "cxx_api/model/model_converter_utils/multi_process.h"
namespace py = pybind11;
@@ -238,6 +239,131 @@ Buffer ModelConverter::ReadFile(const std::string &file) {
}
Buffer ModelConverter::LoadMindIR(const Buffer &model_data) {
if (Py_IsInitialized() == 0) {
MS_LOG_INFO << "Call LoadMindIRInner directly";
return LoadMindIRInner(model_data);
}
MultiProcess multi_process;
Buffer buffer_ret;
auto parent_process = [&model_data, &buffer_ret](MultiProcess *multi_process) -> Status {
MS_EXCEPTION_IF_NULL(multi_process);
// send original model to child
auto status = multi_process->SendMsg(model_data.Data(), model_data.DataSize());
if (!status.IsSuccess()) {
MS_LOG_ERROR << "Send original model to child process failed";
return FAILED;
}
// receive convert model result from child
CreateBufferCall call = [&buffer_ret](size_t msg_len) -> uint8_t * {
buffer_ret.ResizeData(msg_len);
return reinterpret_cast<uint8_t *>(buffer_ret.MutableData());
};
status = multi_process->ReceiveMsg(call);
if (!status.IsSuccess()) {
MS_LOG_ERROR << "Receive result model from child process failed";
return FAILED;
}
return SUCCESS;
};
auto child_process = [this](MultiProcess *multi_process) -> Status {
MS_EXCEPTION_IF_NULL(multi_process);
// receive original model from parent
Buffer model;
CreateBufferCall call = [&model](size_t msg_len) -> uint8_t * {
model.ResizeData(msg_len);
return reinterpret_cast<uint8_t *>(model.MutableData());
};
auto status = multi_process->ReceiveMsg(call);
if (!status.IsSuccess()) {
MS_LOG_ERROR << "Receive original model from parent process failed";
return FAILED;
}
Buffer model_result = LoadMindIRInner(model);
if (model_result.DataSize() == 0) {
MS_LOG_ERROR << "Convert model from MindIR to OM failed";
return FAILED;
}
// send result model to parent
status = multi_process->SendMsg(model_result.Data(), model_result.DataSize());
if (!status.IsSuccess()) {
MS_LOG_ERROR << "Send result model to parent process failed";
return FAILED;
}
return SUCCESS;
};
auto status = multi_process.MainProcess(parent_process, child_process);
if (!status.IsSuccess()) {
MS_LOG_ERROR << "Convert MindIR model to OM model failed";
} else {
MS_LOG_INFO << "Convert MindIR model to OM model success";
}
return buffer_ret;
}
Buffer ModelConverter::LoadAscendIR(const Buffer &model_data) {
if (Py_IsInitialized() == 0) {
MS_LOG_INFO << "Call LoadAscendIRInner directly";
return LoadAscendIRInner(model_data);
}
MultiProcess multi_process;
Buffer buffer_ret;
auto parent_process = [&model_data, &buffer_ret](MultiProcess *multi_process) -> Status {
MS_EXCEPTION_IF_NULL(multi_process);
// send original model to child
auto status = multi_process->SendMsg(model_data.Data(), model_data.DataSize());
if (!status.IsSuccess()) {
MS_LOG_ERROR << "Send original model to child process failed";
return FAILED;
}
// receive convert model result from child
CreateBufferCall call = [&buffer_ret](size_t msg_len) -> uint8_t * {
buffer_ret.ResizeData(msg_len);
return reinterpret_cast<uint8_t *>(buffer_ret.MutableData());
};
status = multi_process->ReceiveMsg(call);
if (!status.IsSuccess()) {
MS_LOG_ERROR << "Receive result model from child process failed";
return FAILED;
}
return SUCCESS;
};
auto child_process = [this](MultiProcess *multi_process) -> Status {
MS_EXCEPTION_IF_NULL(multi_process);
// receive original model from parent
Buffer model;
CreateBufferCall call = [&model](size_t msg_len) -> uint8_t * {
model.ResizeData(msg_len);
return reinterpret_cast<uint8_t *>(model.MutableData());
};
auto status = multi_process->ReceiveMsg(call);
if (!status.IsSuccess()) {
MS_LOG_ERROR << "Receive original model from parent process failed";
return FAILED;
}
Buffer model_result = LoadAscendIRInner(model);
if (model_result.DataSize() == 0) {
MS_LOG_ERROR << "Convert model from AIR to OM failed";
return FAILED;
}
// send result model to parent
status = multi_process->SendMsg(model_result.Data(), model_result.DataSize());
if (!status.IsSuccess()) {
MS_LOG_ERROR << "Send result model to parent process failed";
return FAILED;
}
return SUCCESS;
};
auto status = multi_process.MainProcess(parent_process, child_process);
if (!status.IsSuccess()) {
MS_LOG_ERROR << "Convert AIR model to OM model failed";
} else {
MS_LOG_INFO << "Convert AIR model to OM model success";
}
return buffer_ret;
}
Buffer ModelConverter::LoadMindIRInner(const Buffer &model_data) {
RegAllOp();
auto func_graph = ConvertMindIrToFuncGraph(model_data);
if (func_graph == nullptr) {
MS_LOG(ERROR) << "Convert MindIR to FuncGraph failed.";
@@ -259,7 +385,8 @@ Buffer ModelConverter::LoadMindIR(const Buffer &model_data) {
return om_data;
}
Buffer ModelConverter::LoadAscendIR(const Buffer &model_data) {
Buffer ModelConverter::LoadAscendIRInner(const Buffer &model_data) {
RegAllOp();
ge::Model load_model = ge::Model("loadmodel", "version2");
ge::Status ret =
ge::Model::Load(reinterpret_cast<const uint8_t *>(model_data.Data()), model_data.DataSize(), load_model);

View File

@@ -45,6 +45,9 @@ class ModelConverter {
transform::DfGraphPtr ConvertFuncGraphToAIR(const FuncGraphPtr &anf_graph);
Buffer BuildAirModel(const transform::DfGraphPtr &graph, const std::map<std::string, std::string> &acl_options);
AclModelOptions *options_;
Buffer LoadMindIRInner(const Buffer &model_data);
Buffer LoadAscendIRInner(const Buffer &model_data);
};
} // namespace mindspore::api

View File

@@ -18,6 +18,9 @@
#include "utils/utils.h"
namespace mindspore::api {
const char *kDeviceTypeAscendCL = "AscendCL";
const char *kDeviceTypeAscendMS = "AscendMS";
Status Model::LoadModel(const Buffer &model_data, ModelType type, const std::map<std::string, std::string> &options) {
MS_EXCEPTION_IF_NULL(impl_);
return impl_->LoadModel(model_data, type, options);
@@ -95,4 +98,9 @@ Model::Model(NetWork network, const std::string &device_type, uint32_t device_id
}
Model::~Model() {}
bool Model::CheckModelSupport(const std::string &device_type, ModelType model_type) {
return ModelFactory::Instance().CheckModelSupport(device_type, model_type);
}
} // namespace mindspore::api

View File

@@ -0,0 +1,207 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cxx_api/model/model_converter_utils/multi_process.h"
#include <unistd.h>
#include <sys/wait.h>
#include <algorithm>
#include <vector>
#include <thread>
#include "mindspore/core/utils/log_adapter.h"
#include "cxx_api/model/model_converter_utils/shared_memory.h"
namespace mindspore {
namespace api {
namespace {
uint64_t kSharedMemorySize = 100ull << 20; // 100 MB
}
MultiProcess::MultiProcess() = default;
MultiProcess::~MultiProcess() = default;
Status MultiProcess::MainProcess(ProcessFuncCall parent_process, ProcessFuncCall child_process) {
MS_EXCEPTION_IF_NULL(parent_process);
MS_EXCEPTION_IF_NULL(child_process);
Status ret;
memory_size_ = kSharedMemorySize; // 100 MB
SharedMemory shared_memory;
ret = shared_memory.Create(memory_size_);
if (!ret.IsSuccess()) {
MS_LOG_ERROR << "Create shared memory failed";
return ret;
}
pid_t pid = fork();
if (pid < 0) {
shared_memory.Destroy();
MS_LOG_ERROR << "Fork process to convert model failed";
return FAILED;
}
ret = shared_memory.Attach();
if (!ret.IsSuccess()) {
MS_LOG_ERROR << "Process attach shared memory failed, pid " << pid;
return ret;
}
shmat_addr_ = shared_memory.GetSharedMemoryAddr();
if (shmat_addr_ == nullptr) {
MS_LOG_ERROR << "Get shared memory failed";
return FAILED;
}
shmat_data_addr_ = shmat_addr_ + sizeof(MessageFlag) * 2;
shmat_data_max_size_ = memory_size_ - (shmat_data_addr_ - shmat_addr_);
MS_LOG_INFO << "Shm addr " << (uint64_t)shmat_addr_;
if (pid == 0) {
ChildProcess(child_process);
shared_memory.Detach();
MS_LOG_INFO << "Model converter: child process exit";
exit(0);
} else { // parent process
ret = ParentProcess(parent_process);
shared_memory.Detach();
int status;
wait(&status);
shared_memory.Destroy();
}
return ret;
}
Status MultiProcess::ParentProcess(ProcessFuncCall parent_process) {
auto parent_msg = reinterpret_cast<MessageFlag *>(shmat_addr_);
auto child_msg = reinterpret_cast<MessageFlag *>(shmat_addr_ + sizeof(MessageFlag));
send_msg_ = parent_msg;
receive_msg_ = child_msg;
std::thread heartbeat_thread(MultiProcess::HeartbeatThreadFunc, this);
Status ret;
try {
ret = parent_process(this);
if (!ret.IsSuccess()) {
MS_LOG_ERROR << "Parent process process failed";
}
} catch (const std::runtime_error &ex) {
MS_LOG_ERROR << "Catch parent process runtime error: " << ex.what();
ret = FAILED;
}
stopped_ = true;
send_msg_->stop = true;
heartbeat_thread.join();
return ret;
}
void MultiProcess::ChildProcess(ProcessFuncCall child_process) {
auto parent_msg = reinterpret_cast<MessageFlag *>(shmat_addr_);
auto child_msg = reinterpret_cast<MessageFlag *>(shmat_addr_ + sizeof(MessageFlag));
send_msg_ = child_msg;
receive_msg_ = parent_msg;
std::thread heartbeat_thread(MultiProcess::HeartbeatThreadFunc, this);
try {
auto ret = child_process(this);
if (!ret.IsSuccess()) {
MS_LOG_ERROR << "Child process process failed";
}
} catch (const std::runtime_error &ex) {
MS_LOG_ERROR << "Catch child process runtime error: " << ex.what();
}
stopped_ = true;
send_msg_->stop = true;
heartbeat_thread.join();
}
Status MultiProcess::SendMsg(const void *buffer, uint64_t msg_len) {
MS_LOG_INFO << "Start to send message to peer process, msg len " << msg_len;
send_msg_->msg_total_len = msg_len;
uint64_t cur_offset = 0;
while (msg_len > cur_offset) {
uint64_t sub_msg_len = std::min(msg_len - cur_offset, shmat_data_max_size_);
memcpy_s(shmat_data_addr_, shmat_data_max_size_, static_cast<const uint8_t *>(buffer) + cur_offset, sub_msg_len);
cur_offset += sub_msg_len;
send_msg_->msg_len = sub_msg_len;
send_msg_->read_finish_flag = false;
send_msg_->read_ready_flag = true;
MS_LOG_INFO << "Send start " << cur_offset << ", msg len " << sub_msg_len << ", total len " << msg_len;
while (!send_msg_->read_finish_flag && !peer_stopped_) {
usleep(1000); // 1ms
}
if (peer_stopped_) {
if (!send_msg_->read_finish_flag) {
return FAILED;
}
break;
}
MS_LOG_INFO << "Send end " << cur_offset << ", msg len " << sub_msg_len << ", total len " << msg_len;
}
MS_LOG_INFO << "End to send message to peer process, msg len " << msg_len;
return SUCCESS;
}
Status MultiProcess::ReceiveMsg(CreateBufferCall create_buffer_call) {
uint64_t cur_offset = 0;
uint8_t *msg_buffer = nullptr;
uint64_t msg_len = 0;
do {
MS_LOG_INFO << "Receive start from " << cur_offset;
while (!receive_msg_->read_ready_flag && !peer_stopped_) {
usleep(1000); // 1ms
}
if (peer_stopped_) {
return FAILED;
}
if (msg_buffer == nullptr) {
msg_len = receive_msg_->msg_total_len;
msg_buffer = create_buffer_call(msg_len);
}
memcpy_s(msg_buffer + cur_offset, msg_len - cur_offset, shmat_data_addr_, receive_msg_->msg_len);
cur_offset += receive_msg_->msg_len;
receive_msg_->read_ready_flag = false;
receive_msg_->read_finish_flag = true;
MS_LOG_INFO << "Receive end, current length " << cur_offset << ", total length " << msg_len << std::endl;
} while (msg_len > cur_offset);
return SUCCESS;
}
void MultiProcess::HeartbeatThreadFunc(MultiProcess *multi_process) { multi_process->HeartbeatThreadFuncInner(); }
void MultiProcess::HeartbeatThreadFuncInner() {
uint64_t last_beat_cnt = 0;
uint64_t repeat_cnt = 0;
while (!stopped_) {
if (receive_msg_->stop) {
peer_stopped_ = true;
MS_LOG_WARNING << "Peer stopped";
break;
}
uint64_t heartbeat_gap = receive_msg_->heartbeat - last_beat_cnt;
if (heartbeat_gap > 0 && heartbeat_gap < 1024) {
last_beat_cnt = receive_msg_->heartbeat;
repeat_cnt = 0;
} else {
repeat_cnt++;
if (repeat_cnt > 30) { // 30*100ms = 3s no reply
peer_stopped_ = true;
MS_LOG_WARNING << "Peer stopped";
break;
}
}
send_msg_->heartbeat += 1;
usleep(100000); // sleep 100 ms
}
}
} // namespace api
} // namespace mindspore

View File

@@ -0,0 +1,68 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_CXXAPI_MULTI_PROCESS_H
#define MINDSPORE_CCSRC_CXXAPI_MULTI_PROCESS_H
#include <iostream>
#include <functional>
#include "include/api/status.h"
namespace mindspore {
namespace api {
struct MessageFlag {
uint64_t heartbeat = 0;
uint64_t stop = false;
uint64_t msg_len = 0;
uint64_t msg_total_len = 0;
uint64_t read_ready_flag = false;
uint64_t read_finish_flag = false;
};
class MultiProcess;
using ProcessFuncCall = std::function<Status(MultiProcess *multi_process)>;
using CreateBufferCall = std::function<uint8_t *(size_t msg_len)>;
class MultiProcess {
public:
MultiProcess();
~MultiProcess();
Status MainProcess(ProcessFuncCall parent_process, ProcessFuncCall child_process);
Status SendMsg(const void *buffer, uint64_t msg_len);
Status ReceiveMsg(CreateBufferCall create_buffer_call);
private:
uint8_t *shmat_addr_ = nullptr;
uint8_t *shmat_data_addr_ = nullptr;
uint64_t shmat_data_max_size_ = 0;
uint64_t memory_size_ = 0;
bool peer_stopped_ = false;
bool stopped_ = false;
MessageFlag *send_msg_ = nullptr;
MessageFlag *receive_msg_ = nullptr;
static void HeartbeatThreadFunc(MultiProcess *multi_process);
void HeartbeatThreadFuncInner();
Status ParentProcess(ProcessFuncCall parent_process);
void ChildProcess(ProcessFuncCall child_process);
};
} // namespace api
} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXXAPI_MULTI_PROCESS_H
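A condensed sketch of the intended MultiProcess usage, mirroring the parent/child pattern in model_converter.cc; the request/reply Buffers and the work() conversion step are placeholders, not part of the API.

// Hedged sketch, assuming Buffer request/reply and a work() routine exist in scope.
MultiProcess mp;
Status st = mp.MainProcess(
  // parent callback: runs in the original process
  [&](MultiProcess *proc) -> Status {
    if (!proc->SendMsg(request.Data(), request.DataSize()).IsSuccess()) {  // push job to child
      return FAILED;
    }
    return proc->ReceiveMsg([&](size_t len) -> uint8_t * {                 // allocate reply buffer lazily
      reply.ResizeData(len);
      return reinterpret_cast<uint8_t *>(reply.MutableData());
    });
  },
  // child callback: runs in the forked process and reports the result back
  [&](MultiProcess *proc) -> Status {
    Buffer job;
    proc->ReceiveMsg([&](size_t len) -> uint8_t * {
      job.ResizeData(len);
      return reinterpret_cast<uint8_t *>(job.MutableData());
    });
    Buffer result = work(job);  // hypothetical conversion step
    return proc->SendMsg(result.Data(), result.DataSize());
  });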

View File

@@ -0,0 +1,69 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cxx_api/model/model_converter_utils/shared_memory.h"
#include <sys/shm.h>
#include <sys/stat.h>
#include <string>
#include "mindspore/core/utils/log_adapter.h"
namespace mindspore {
namespace api {
Status SharedMemory::Create(uint64_t memory_size) {
auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP;
shm_id_ = shmget(IPC_PRIVATE, memory_size, IPC_CREAT | IPC_EXCL | access_mode);
if (shm_id_ == -1) {
MS_LOG_ERROR << "Shared memory creation failed. Errno " + std::to_string(errno);
return FAILED;
}
MS_LOG_INFO << "shmget success, shm id " << shm_id_;
return SUCCESS;
}
Status SharedMemory::Attach() {
void *shmat_addr = shmat(shm_id_, nullptr, 0);
if (shmat_addr == reinterpret_cast<void *>(-1)) {
MS_LOG_ERROR << "Shared memory attach failed. Errno " + std::to_string(errno);
return FAILED;
}
shmat_addr_ = reinterpret_cast<uint8_t *>(shmat_addr);
return SUCCESS;
}
void SharedMemory::Detach() {
if (shmat_addr_) {
auto err = shmdt(shmat_addr_);
if (err == -1) {
MS_LOG_ERROR << "Shared memory detach failed. Errno " + std::to_string(errno);
return;
}
}
shmat_addr_ = nullptr;
}
void SharedMemory::Destroy() {
// Remove the shared memory and never mind about the return code.
auto err = shmctl(shm_id_, IPC_RMID, nullptr);
if (err == -1) {
std::string errMsg = "Unable to remove shared memory with id " + std::to_string(shm_id_);
errMsg += ". Errno :" + std::to_string(errno);
errMsg += "\nPlesae remove it manually using ipcrm -m command";
MS_LOG_ERROR << errMsg;
}
}
} // namespace api
} // namespace mindspore

View File

@@ -0,0 +1,41 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_CXXAPI_SHARED_MEMORY_H
#define MINDSPORE_CCSRC_CXXAPI_SHARED_MEMORY_H
#include <iostream>
#include "include/api/status.h"
namespace mindspore {
namespace api {
class SharedMemory {
public:
Status Create(uint64_t memory_size);
Status Attach();
void Detach();
void Destroy();
uint8_t *GetSharedMemoryAddr() { return shmat_addr_; }
private:
int shm_id_ = -1;
uint8_t *shmat_addr_ = nullptr;
};
} // namespace api
} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXXAPI_SHARED_MEMORY_H
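SharedMemory is a thin wrapper over System V shared memory; a minimal sketch of its lifecycle as MultiProcess drives it (error handling trimmed, the 100 MB size mirrors kSharedMemorySize):

// Hedged sketch of the Create/Attach/Detach/Destroy lifecycle.
SharedMemory shm;
if (!shm.Create(100ull << 20).IsSuccess()) {  // shmget(IPC_PRIVATE, ...)
  return FAILED;
}
// after fork(), parent and child each attach the same segment
if (!shm.Attach().IsSuccess()) {              // shmat()
  return FAILED;
}
uint8_t *base = shm.GetSharedMemoryAddr();
// ... exchange messages through base ...
shm.Detach();                                 // shmdt()
shm.Destroy();                                // shmctl(IPC_RMID); parent only, after wait()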

View File

@@ -70,6 +70,12 @@ class ModelFactory {
return nullptr;
}
bool CheckModelSupport(const std::string &device_type, ModelType /*model_type*/) {
return std::any_of(
model_creators_.begin(), model_creators_.end(),
[&device_type](const std::pair<std::string, ModelCreator> &item) { return item.first == device_type; });
}
private:
ModelFactory() = default;
~ModelFactory() = default;
@@ -86,7 +92,7 @@ class ModelRegistrar {
#define API_REG_MODEL(DEVICE_NAME, MODEL_CLASS) \
static const ModelRegistrar g_api_model_registrar__##DEVICE_NAME##_##_reg( \
#DEVICE_NAME, [](uint32_t device_id) { return std::make_shared<MODEL_CLASS>(device_id); });
kDeviceType##DEVICE_NAME, [](uint32_t device_id) { return std::make_shared<MODEL_CLASS>(device_id); });
} // namespace mindspore::api
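With the registrar change, the device key now comes from the matching kDeviceType constant instead of the stringified macro argument; API_REG_MODEL(AscendMS, MsModel), for example, now expands roughly to:

static const ModelRegistrar g_api_model_registrar__AscendMS__reg(
  kDeviceTypeAscendMS, [](uint32_t device_id) { return std::make_shared<MsModel>(device_id); });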

View File

@@ -0,0 +1,418 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cxx_api/model/ms/ms_model.h"
#include <memory>
#include <algorithm>
#include <fstream>
#include "utils/load_onnx/anf_converter.h"
#include "backend/session/session_basic.h"
#include "backend/session/session_factory.h"
#include "backend/session/executor_manager.h"
#include "base/base_ref_utils.h"
#include "backend/kernel_compiler/oplib/oplib.h"
#include "utils/context/context_extends.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "pybind11/pybind11.h"
#include "pybind11/embed.h"
#ifdef ENABLE_D
#include "utils/ms_context.h"
#endif
using std::string;
using std::vector;
namespace py = pybind11;
namespace mindspore {
namespace api {
MsModel::MsModel(uint32_t device_id) : device_id_(device_id) {}
MsModel::~MsModel() = default;
TypeId TransInferDataType2TypeId(DataType data_type) {
const std::map<api::DataType, TypeId> type2id_map{
{api::kMsUnknown, TypeId::kNumberTypeBegin}, {api::kMsBool, TypeId::kNumberTypeBool},
{api::kMsInt8, TypeId::kNumberTypeInt8}, {api::kMsUint8, TypeId::kNumberTypeUInt8},
{api::kMsInt16, TypeId::kNumberTypeInt16}, {api::kMsUint16, TypeId::kNumberTypeUInt16},
{api::kMsInt32, TypeId::kNumberTypeInt32}, {api::kMsUint32, TypeId::kNumberTypeUInt32},
{api::kMsInt64, TypeId::kNumberTypeInt64}, {api::kMsUint64, TypeId::kNumberTypeUInt64},
{api::kMsFloat16, TypeId::kNumberTypeFloat16}, {api::kMsFloat32, TypeId::kNumberTypeFloat32},
{api::kMsFloat64, TypeId::kNumberTypeFloat64},
};
auto it = type2id_map.find(data_type);
if (it == type2id_map.end()) {
MS_LOG_WARNING << "Unsupported MSI data type " << data_type;
return TypeId::kNumberTypeBegin;
} else {
return it->second;
}
}
DataType TransTypeId2InferDataType(TypeId type_id) {
const std::map<TypeId, api::DataType> id2type_map{
{TypeId::kNumberTypeBegin, api::kMsUnknown}, {TypeId::kNumberTypeBool, api::kMsBool},
{TypeId::kNumberTypeFloat64, api::kMsFloat64}, {TypeId::kNumberTypeInt8, api::kMsInt8},
{TypeId::kNumberTypeUInt8, api::kMsUint8}, {TypeId::kNumberTypeInt16, api::kMsInt16},
{TypeId::kNumberTypeUInt16, api::kMsUint16}, {TypeId::kNumberTypeInt32, api::kMsInt32},
{TypeId::kNumberTypeUInt32, api::kMsUint32}, {TypeId::kNumberTypeInt64, api::kMsInt64},
{TypeId::kNumberTypeUInt64, api::kMsUint64}, {TypeId::kNumberTypeFloat16, api::kMsFloat16},
{TypeId::kNumberTypeFloat32, api::kMsFloat32},
};
auto it = id2type_map.find(type_id);
if (it == id2type_map.end()) {
MS_LOG_WARNING << "Unsupported data id " << type_id;
return api::kMsUnknown;
} else {
return it->second;
}
}
Buffer MsModel::ReadFile(const std::string &file) {
if (file.empty()) {
MS_LOG(ERROR) << "file is nullptr";
return Buffer();
}
std::ifstream ifs(file);
if (!ifs.good()) {
MS_LOG(ERROR) << "file: " << file << " is not exist";
return Buffer();
}
if (!ifs.is_open()) {
MS_LOG(ERROR) << "file: " << file << "open failed";
return Buffer();
}
ifs.seekg(0, std::ios::end);
size_t size = ifs.tellg();
Buffer buffer;
buffer.ResizeData(size);
ifs.seekg(0, std::ios::beg);
ifs.read(static_cast<char *>(buffer.MutableData()), size);
ifs.close();
return buffer;
}
Status MsModel::LoadModel(const Buffer &model_data, ModelType type, const std::map<std::string, std::string> &options) {
auto status = InitEnv({});
if (status != SUCCESS) {
MS_LOG(ERROR) << "Init env failed";
return FAILED;
}
std::shared_ptr<FuncGraph> anf_graph;
try {
anf_graph =
lite::AnfConverter::RunAnfConverter(static_cast<const char *>(model_data.Data()), model_data.DataSize());
} catch (std::exception &e) {
MS_LOG(ERROR) << "Inference LoadModel failed";
return FAILED;
}
Status ret = CompileGraph(anf_graph);
if (ret != SUCCESS) {
MS_LOG(ERROR) << "Compile graph model failed";
return FAILED;
}
session_impl_->GetModelInputsInfo(graph_id_, &inputs_, &input_names_);
session_impl_->GetModelOutputsInfo(graph_id_, &outputs_, &output_names_);
if (inputs_.empty() || inputs_.size() != input_names_.size()) {
MS_LOG_ERROR << "Get model inputs info failed";
return FAILED;
}
if (outputs_.empty() || outputs_.size() != output_names_.size()) {
MS_LOG_ERROR << "Get model outputs info failed";
return FAILED;
}
MS_LOG(INFO) << "Load model success";
#ifdef ENABLE_D
// set d context
rtError_t rt_ret = rtCtxGetCurrent(&context_);
if (rt_ret != RT_ERROR_NONE || context_ == nullptr) {
MS_LOG(ERROR) << "the ascend device context is null";
return FAILED;
}
#endif
load_flag_ = true;
return SUCCESS;
}
Status MsModel::LoadModel(const std::string &file_name, ModelType type,
const std::map<std::string, std::string> &options) {
auto graphBuf = ReadFile(file_name);
if (graphBuf.DataSize() == 0) {
MS_LOG(ERROR) << "Read model file failed, file name is " << file_name.c_str();
return FAILED;
}
auto status = LoadModel(graphBuf, type, options);
if (status != SUCCESS) {
MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
return FAILED;
}
return SUCCESS;
}
Status MsModel::UnloadModel() {
if (!load_flag_) {
MS_LOG_ERROR << "Model has not been loaded";
return FAILED;
}
FinalizeEnv();
load_flag_ = false;
return SUCCESS;
}
Status MsModel::Train(const DataSet &, std::map<std::string, Buffer> *) {
MS_LOG(ERROR) << "Unsupported feature.";
return FAILED;
}
Status MsModel::Eval(const DataSet &, std::map<std::string, Buffer> *) {
MS_LOG(ERROR) << "Unsupported feature.";
return FAILED;
}
Status MsModel::Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs) {
MS_EXCEPTION_IF_NULL(outputs);
if (!load_flag_) {
MS_LOG(ERROR) << "No model is loaded, predict failed.";
return FAILED;
}
if (inputs.size() != inputs_.size()) {
MS_LOG(ERROR) << "inputs count not match, required count " << inputs_.size() << ", given count " << inputs.size();
return INVALID_INPUTS;
}
std::vector<Buffer> request;
std::vector<Buffer> reply;
for (size_t i = 0; i < inputs_.size(); ++i) {
const auto &input_name = input_names_[i];
auto iter = inputs.find(input_name);
if (iter == inputs.end()) {
MS_LOG(ERROR) << "Model missing input " << input_name;
return INVALID_INPUTS;
}
if (iter->second.DataSize() != inputs_[i]->Size()) {
MS_LOG(ERROR) << "input " << i << " data size not match, required size " << inputs_[i]->Size() << ", given count "
<< iter->second.DataSize();
return INVALID_INPUTS;
}
request.push_back(iter->second);
}
if (ExecuteModel(request, &reply) != SUCCESS) {
MS_LOG(ERROR) << "Execute Model Failed";
return FAILED;
}
if (outputs_.size() != reply.size()) {
MS_LOG(ERROR) << "Predict output size " << reply.size() << " not match output size got from model info "
<< outputs_.size();
return FAILED;
}
outputs->clear();
for (size_t i = 0; i < reply.size(); i++) {
outputs->emplace(output_names_[i], reply[i]);
}
return SUCCESS;
}
Status MsModel::ExecuteModel(const std::vector<Buffer> &request, std::vector<Buffer> *reply) {
MS_EXCEPTION_IF_NULL(reply);
#ifdef ENABLE_D
if (context_ == nullptr) {
MS_LOG(ERROR) << "rtCtx is nullptr";
return FAILED;
}
rtError_t rt_ret = rtCtxSetCurrent(context_);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "set Ascend rtCtx failed";
return FAILED;
}
#endif
vector<tensor::TensorPtr> inputs;
for (size_t i = 0; i < request.size(); i++) {
auto &item = request[i];
auto input = inputs_[i];
if (input->Size() != item.DataSize()) {
MS_LOG(ERROR) << "Predict input " << i << " data size " << item.DataSize() << " not match model input data size "
<< input->Size();
return FAILED;
}
auto ret = memcpy_s(input->data_c(), input->Size(), item.Data(), item.DataSize());
if (ret != SUCCESS) {
MS_LOG(ERROR) << "Tensor copy failed";
return FAILED;
}
inputs.push_back(input);
}
vector<tensor::TensorPtr> outputs = RunGraph(inputs);
if (outputs.empty()) {
MS_LOG(ERROR) << "Execute Model Failed";
return FAILED;
}
reply->clear();
std::transform(outputs.begin(), outputs.end(), std::back_inserter(*reply),
[](const tensor::TensorPtr &tensor) { return Buffer(tensor->data_c(), tensor->Size()); });
return SUCCESS;
}
Status MsModel::FinalizeEnv() {
MS_LOG_INFO << "Start finalize env";
py::gil_scoped_acquire acquire;
session::ExecutorManager::Instance().Clear();
device::KernelRuntimeManager::Instance().ClearRuntimeResource();
auto ms_context = MsContext::GetInstance();
if (ms_context == nullptr) {
MS_LOG(ERROR) << "Get Context failed!";
return FAILED;
}
if (!context::CloseTsd(ms_context)) {
MS_LOG(ERROR) << "Inference CloseTsd failed!";
return FAILED;
}
MS_LOG_INFO << "End finalize env";
return SUCCESS;
}
std::shared_ptr<FuncGraph> MsModel::LoadModel(const char *model_buf, size_t size, const std::string &device) {
MS_EXCEPTION_IF_NULL(model_buf);
try {
auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size);
return anf_graph;
} catch (std::exception &e) {
MS_LOG(ERROR) << "Inference LoadModel failed: " << e.what();
return nullptr;
}
}
void MsModel::RegAllOp() {
static std::mutex init_mutex;
static bool Initialized = false;
std::lock_guard<std::mutex> lock(init_mutex);
if (Initialized) {
return;
}
Initialized = true;
auto ms_context_instance = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context_instance);
ms_context_instance->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
try {
std::shared_ptr<py::scoped_interpreter> guard;
if (Py_IsInitialized() == 0) {
guard = std::make_shared<py::scoped_interpreter>();
}
py::module c_expression = py::module::import("mindspore._c_expression");
size_t ops_info_long = c_expression.attr("OpInfoLoaderPy")().attr("get_all_ops_info")().cast<size_t>();
auto all_ops_info = reinterpret_cast<std::vector<kernel::OpInfo *> *>(ops_info_long);
for (auto op_info : *all_ops_info) {
kernel::OpLib::RegOpInfo(std::shared_ptr<kernel::OpInfo>(op_info));
}
all_ops_info->clear();
delete all_ops_info;
} catch (const std::runtime_error &ex) {
MS_LOG_EXCEPTION << ex.what();
}
}
Status MsModel::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr) {
MS_ASSERT(session_impl_ != nullptr);
try {
graph_id_ = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
py::gil_scoped_release gil_release;
return SUCCESS;
} catch (std::exception &e) {
MS_LOG(ERROR) << "Inference CompileGraph failed: " << e.what();
return FAILED;
}
}
std::vector<tensor::TensorPtr> MsModel::RunGraph(const std::vector<tensor::TensorPtr> &inputs) {
try {
VectorRef outputs;
session_impl_->RunGraph(graph_id_, inputs, &outputs);
return TransformVectorRefToMultiTensor(outputs);
} catch (std::exception &e) {
MS_LOG(ERROR) << "Inference Rungraph failed: " << e.what();
return std::vector<tensor::TensorPtr>();
}
}
Status MsModel::InitEnv(const std::unordered_map<std::string, std::string> &other_options) {
RegAllOp();
auto ms_context = MsContext::GetInstance();
if (ms_context == nullptr) {
MS_LOG(ERROR) << "Get Context failed!";
return FAILED;
}
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_);
ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kAscendDevice);
if (!context::OpenTsd(ms_context)) {
MS_LOG(ERROR) << "Session init OpenTsd failed!";
return FAILED;
}
session_impl_ = session::SessionFactory::Get().Create(kDavinciInferenceDevice);
if (session_impl_ == nullptr) {
MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << kDavinciInferenceDevice
<< " is available.";
return FAILED;
}
session_impl_->Init(device_id_);
return SUCCESS;
}
Status MsModel::CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const {
MS_ASSERT(session_impl_ != nullptr);
std::string error_msg;
if (!session_impl_->CheckModelInputs(graph_id, inputs, &error_msg)) {
return Status(INVALID_INPUTS, error_msg);
}
return SUCCESS;
}
Status MsModel::GetInputsInfo(std::vector<Tensor> *tensor_list) const {
MS_EXCEPTION_IF_NULL(tensor_list);
tensor_list->clear();
for (size_t i = 0; i < inputs_.size(); i++) {
auto &tensor = inputs_[i];
Tensor infer_tensor;
infer_tensor.SetName(input_names_[i]);
infer_tensor.SetDataType(TransTypeId2InferDataType(tensor->data_type()));
infer_tensor.SetShape(tensor->shape());
tensor_list->push_back(infer_tensor);
}
return SUCCESS;
}
Status MsModel::GetOutputsInfo(std::vector<Tensor> *tensor_list) const {
MS_EXCEPTION_IF_NULL(tensor_list);
tensor_list->clear();
for (size_t i = 0; i < outputs_.size(); i++) {
auto &tensor = outputs_[i];
Tensor infer_tensor;
infer_tensor.SetName(output_names_[i]);
infer_tensor.SetDataType(TransTypeId2InferDataType(tensor->data_type()));
infer_tensor.SetShape(tensor->shape());
tensor_list->push_back(infer_tensor);
}
return SUCCESS;
}
} // namespace api
} // namespace mindspore
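Putting the pieces together, a hedged end-to-end sketch of the new 910 path through MsModel; the ModelType::kMindIR value, the input name "x", and host_data/host_size are assumptions, and real callers reach this class through mindspore::api::Model.

// Hedged sketch, not taken verbatim from this commit.
MsModel ms_model(0);  // device_id 0
if (ms_model.LoadModel("net.mindir", ModelType::kMindIR, {}) != SUCCESS) {  // compiles graph, caches inputs_/outputs_
  return FAILED;
}
std::vector<Tensor> input_info;
ms_model.GetInputsInfo(&input_info);  // names, dtypes and shapes come from GetModelInputsInfo
std::map<std::string, Buffer> inputs, outputs;
inputs.emplace("x", Buffer(host_data, host_size));  // keyed by input name
if (ms_model.Predict(inputs, &outputs) != SUCCESS) {
  return FAILED;
}
// outputs come back keyed "output0", "output1", ... per GetModelOutputsInfo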

View File

@@ -0,0 +1,85 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_SESSION_SESSION_H
#define MINDSPORE_CCSRC_SESSION_SESSION_H
#include <vector>
#include <string>
#include <unordered_map>
#include <utility>
#include <memory>
#include <map>
#include "backend/session/session_basic.h"
#include "ir/anf.h"
#include "include/api/status.h"
#include "cxx_api/model/model_impl.h"
#ifdef ENABLE_D
#include "runtime/context.h"
#endif
namespace mindspore {
namespace api {
class MsModel : public ModelImpl {
public:
explicit MsModel(uint32_t device_id);
~MsModel();
Status LoadModel(const Buffer &model_data, ModelType type,
const std::map<std::string, std::string> &options) override;
Status LoadModel(const std::string &file_name, ModelType type,
const std::map<std::string, std::string> &options) override;
Status UnloadModel() override;
Status Train(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override;
Status Eval(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override;
Status Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs) override;
Status GetInputsInfo(std::vector<Tensor> *tensor_list) const override;
Status GetOutputsInfo(std::vector<Tensor> *tensor_list) const override;
Status InitEnv(const std::unordered_map<std::string, std::string> &other_options);
Status FinalizeEnv();
private:
std::shared_ptr<session::SessionBasic> session_impl_ = nullptr;
uint32_t graph_id_ = 0;
std::string device_type_;
int32_t device_id_ = 0;
#ifdef ENABLE_D
rtContext_t context_ = nullptr;
#endif
std::vector<tensor::TensorPtr> inputs_;
std::vector<tensor::TensorPtr> outputs_;
std::vector<std::string> input_names_;
std::vector<std::string> output_names_;
bool load_flag_ = false;
std::shared_ptr<FuncGraph> LoadModel(const char *model_buf, size_t size, const std::string &device);
Buffer ReadFile(const std::string &file);
static void RegAllOp();
Status CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr);
Status CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const;
std::vector<tensor::TensorPtr> RunGraph(const std::vector<tensor::TensorPtr> &inputs);
Status ExecuteModel(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs);
};
API_REG_MODEL(AscendMS, MsModel);
} // namespace api
} // namespace mindspore
#endif // MINDSPORE_CCSRC_SESSION_SESSION_H