!3794 modify inference return type
Merge pull request !3794 from dinghao/master
This commit is contained in:
commit
2b56562770
|
@ -41,6 +41,7 @@ cmake-build-debug
|
||||||
*.pb.h
|
*.pb.h
|
||||||
*.pb.cc
|
*.pb.cc
|
||||||
*.pb
|
*.pb
|
||||||
|
*_grpc.py
|
||||||
|
|
||||||
# Object files
|
# Object files
|
||||||
*.o
|
*.o
|
||||||
|
|
|
@ -24,20 +24,20 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace inference {
|
namespace inference {
|
||||||
|
enum Status { SUCCESS = 0, FAILED, INVALID_INPUTS };
|
||||||
class MS_API InferSession {
|
class MS_API InferSession {
|
||||||
public:
|
public:
|
||||||
InferSession() = default;
|
InferSession() = default;
|
||||||
virtual ~InferSession() = default;
|
virtual ~InferSession() = default;
|
||||||
virtual bool InitEnv(const std::string &device_type, uint32_t device_id) = 0;
|
virtual Status InitEnv(const std::string &device_type, uint32_t device_id) = 0;
|
||||||
virtual bool FinalizeEnv() = 0;
|
virtual Status FinalizeEnv() = 0;
|
||||||
virtual bool LoadModelFromFile(const std::string &file_name, uint32_t &model_id) = 0;
|
virtual Status LoadModelFromFile(const std::string &file_name, uint32_t &model_id) = 0;
|
||||||
virtual bool UnloadModel(uint32_t model_id) = 0;
|
virtual Status UnloadModel(uint32_t model_id) = 0;
|
||||||
// override this method to avoid request/reply data copy
|
// override this method to avoid request/reply data copy
|
||||||
virtual bool ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) = 0;
|
virtual Status ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) = 0;
|
||||||
|
|
||||||
virtual bool ExecuteModel(uint32_t model_id, const std::vector<InferTensor> &inputs,
|
virtual Status ExecuteModel(uint32_t model_id, const std::vector<InferTensor> &inputs,
|
||||||
std::vector<InferTensor> &outputs) {
|
std::vector<InferTensor> &outputs) {
|
||||||
VectorInferTensorWrapRequest request(inputs);
|
VectorInferTensorWrapRequest request(inputs);
|
||||||
VectorInferTensorWrapReply reply(outputs);
|
VectorInferTensorWrapReply reply(outputs);
|
||||||
return ExecuteModel(model_id, request, reply);
|
return ExecuteModel(model_id, request, reply);
|
||||||
|
|
|
@ -37,8 +37,8 @@ namespace mindspore::inference {
|
||||||
std::shared_ptr<InferSession> InferSession::CreateSession(const std::string &device, uint32_t device_id) {
|
std::shared_ptr<InferSession> InferSession::CreateSession(const std::string &device, uint32_t device_id) {
|
||||||
try {
|
try {
|
||||||
auto session = std::make_shared<MSInferSession>();
|
auto session = std::make_shared<MSInferSession>();
|
||||||
bool ret = session->InitEnv(device, device_id);
|
Status ret = session->InitEnv(device, device_id);
|
||||||
if (!ret) {
|
if (ret != SUCCESS) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
return session;
|
return session;
|
||||||
|
@ -84,21 +84,21 @@ std::shared_ptr<std::vector<char>> MSInferSession::ReadFile(const std::string &f
|
||||||
return buf;
|
return buf;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool MSInferSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) {
|
Status MSInferSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) {
|
||||||
auto graphBuf = ReadFile(file_name);
|
auto graphBuf = ReadFile(file_name);
|
||||||
if (graphBuf == nullptr) {
|
if (graphBuf == nullptr) {
|
||||||
MS_LOG(ERROR) << "Read model file failed, file name is " << file_name.c_str();
|
MS_LOG(ERROR) << "Read model file failed, file name is " << file_name.c_str();
|
||||||
return false;
|
return FAILED;
|
||||||
}
|
}
|
||||||
auto graph = LoadModel(graphBuf->data(), graphBuf->size(), device_type_);
|
auto graph = LoadModel(graphBuf->data(), graphBuf->size(), device_type_);
|
||||||
if (graph == nullptr) {
|
if (graph == nullptr) {
|
||||||
MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
|
MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
|
||||||
return false;
|
return FAILED;
|
||||||
}
|
}
|
||||||
bool ret = CompileGraph(graph, model_id);
|
Status ret = CompileGraph(graph, model_id);
|
||||||
if (!ret) {
|
if (ret != SUCCESS) {
|
||||||
MS_LOG(ERROR) << "Compile graph model failed, file name is " << file_name.c_str();
|
MS_LOG(ERROR) << "Compile graph model failed, file name is " << file_name.c_str();
|
||||||
return false;
|
return FAILED;
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << "Load model from file " << file_name << " success";
|
MS_LOG(INFO) << "Load model from file " << file_name << " success";
|
||||||
|
|
||||||
|
@ -107,14 +107,14 @@ bool MSInferSession::LoadModelFromFile(const std::string &file_name, uint32_t &m
|
||||||
rtError_t rt_ret = rtCtxGetCurrent(&context_);
|
rtError_t rt_ret = rtCtxGetCurrent(&context_);
|
||||||
if (rt_ret != RT_ERROR_NONE || context_ == nullptr) {
|
if (rt_ret != RT_ERROR_NONE || context_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "the ascend device context is null";
|
MS_LOG(ERROR) << "the ascend device context is null";
|
||||||
return false;
|
return FAILED;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
return true;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool MSInferSession::UnloadModel(uint32_t model_id) { return true; }
|
Status MSInferSession::UnloadModel(uint32_t model_id) { return SUCCESS; }
|
||||||
|
|
||||||
tensor::TensorPtr ServingTensor2MSTensor(const InferTensorBase &out_tensor) {
|
tensor::TensorPtr ServingTensor2MSTensor(const InferTensorBase &out_tensor) {
|
||||||
std::vector<int> shape;
|
std::vector<int> shape;
|
||||||
|
@ -170,16 +170,16 @@ void MSTensor2ServingTensor(tensor::TensorPtr ms_tensor, InferTensorBase &out_te
|
||||||
out_tensor.set_data(ms_tensor->data_c(), ms_tensor->Size());
|
out_tensor.set_data(ms_tensor->data_c(), ms_tensor->Size());
|
||||||
}
|
}
|
||||||
|
|
||||||
bool MSInferSession::ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) {
|
Status MSInferSession::ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) {
|
||||||
#ifdef ENABLE_D
|
#ifdef ENABLE_D
|
||||||
if (context_ == nullptr) {
|
if (context_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "rtCtx is nullptr";
|
MS_LOG(ERROR) << "rtCtx is nullptr";
|
||||||
return false;
|
return FAILED;
|
||||||
}
|
}
|
||||||
rtError_t rt_ret = rtCtxSetCurrent(context_);
|
rtError_t rt_ret = rtCtxSetCurrent(context_);
|
||||||
if (rt_ret != RT_ERROR_NONE) {
|
if (rt_ret != RT_ERROR_NONE) {
|
||||||
MS_LOG(ERROR) << "set Ascend rtCtx failed";
|
MS_LOG(ERROR) << "set Ascend rtCtx failed";
|
||||||
return false;
|
return FAILED;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -187,47 +187,47 @@ bool MSInferSession::ExecuteModel(uint32_t model_id, const RequestBase &request,
|
||||||
for (size_t i = 0; i < request.size(); i++) {
|
for (size_t i = 0; i < request.size(); i++) {
|
||||||
if (request[i] == nullptr) {
|
if (request[i] == nullptr) {
|
||||||
MS_LOG(ERROR) << "Execute Model " << model_id << " Failed, input tensor is null, index " << i;
|
MS_LOG(ERROR) << "Execute Model " << model_id << " Failed, input tensor is null, index " << i;
|
||||||
return false;
|
return FAILED;
|
||||||
}
|
}
|
||||||
auto input = ServingTensor2MSTensor(*request[i]);
|
auto input = ServingTensor2MSTensor(*request[i]);
|
||||||
if (input == nullptr) {
|
if (input == nullptr) {
|
||||||
MS_LOG(ERROR) << "Tensor convert failed";
|
MS_LOG(ERROR) << "Tensor convert failed";
|
||||||
return false;
|
return FAILED;
|
||||||
}
|
}
|
||||||
inputs.push_back(input);
|
inputs.push_back(input);
|
||||||
}
|
}
|
||||||
if (!CheckModelInputs(model_id, inputs)) {
|
if (!CheckModelInputs(model_id, inputs)) {
|
||||||
MS_LOG(ERROR) << "Check Model " << model_id << " Inputs Failed";
|
MS_LOG(ERROR) << "Check Model " << model_id << " Inputs Failed";
|
||||||
return false;
|
return INVALID_INPUTS;
|
||||||
}
|
}
|
||||||
vector<tensor::TensorPtr> outputs = RunGraph(model_id, inputs);
|
vector<tensor::TensorPtr> outputs = RunGraph(model_id, inputs);
|
||||||
if (outputs.empty()) {
|
if (outputs.empty()) {
|
||||||
MS_LOG(ERROR) << "Execute Model " << model_id << " Failed";
|
MS_LOG(ERROR) << "Execute Model " << model_id << " Failed";
|
||||||
return false;
|
return FAILED;
|
||||||
}
|
}
|
||||||
reply.clear();
|
reply.clear();
|
||||||
for (const auto &tensor : outputs) {
|
for (const auto &tensor : outputs) {
|
||||||
auto out_tensor = reply.add();
|
auto out_tensor = reply.add();
|
||||||
if (out_tensor == nullptr) {
|
if (out_tensor == nullptr) {
|
||||||
MS_LOG(ERROR) << "Execute Model " << model_id << " Failed, add output tensor failed";
|
MS_LOG(ERROR) << "Execute Model " << model_id << " Failed, add output tensor failed";
|
||||||
return false;
|
return FAILED;
|
||||||
}
|
}
|
||||||
MSTensor2ServingTensor(tensor, *out_tensor);
|
MSTensor2ServingTensor(tensor, *out_tensor);
|
||||||
}
|
}
|
||||||
return true;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool MSInferSession::FinalizeEnv() {
|
Status MSInferSession::FinalizeEnv() {
|
||||||
auto ms_context = MsContext::GetInstance();
|
auto ms_context = MsContext::GetInstance();
|
||||||
if (ms_context == nullptr) {
|
if (ms_context == nullptr) {
|
||||||
MS_LOG(ERROR) << "Get Context failed!";
|
MS_LOG(ERROR) << "Get Context failed!";
|
||||||
return false;
|
return FAILED;
|
||||||
}
|
}
|
||||||
if (!ms_context->CloseTsd()) {
|
if (!ms_context->CloseTsd()) {
|
||||||
MS_LOG(ERROR) << "Inference CloseTsd failed!";
|
MS_LOG(ERROR) << "Inference CloseTsd failed!";
|
||||||
return false;
|
return FAILED;
|
||||||
}
|
}
|
||||||
return true;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<FuncGraph> MSInferSession::LoadModel(const char *model_buf, size_t size, const std::string &device) {
|
std::shared_ptr<FuncGraph> MSInferSession::LoadModel(const char *model_buf, size_t size, const std::string &device) {
|
||||||
|
@ -292,16 +292,16 @@ void MSInferSession::RegAllOp() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool MSInferSession::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id) {
|
Status MSInferSession::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id) {
|
||||||
MS_ASSERT(session_impl_ != nullptr);
|
MS_ASSERT(session_impl_ != nullptr);
|
||||||
try {
|
try {
|
||||||
auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
|
auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
|
||||||
py::gil_scoped_release gil_release;
|
py::gil_scoped_release gil_release;
|
||||||
model_id = graph_id;
|
model_id = graph_id;
|
||||||
return true;
|
return SUCCESS;
|
||||||
} catch (std::exception &e) {
|
} catch (std::exception &e) {
|
||||||
MS_LOG(ERROR) << "Inference CompileGraph failed";
|
MS_LOG(ERROR) << "Inference CompileGraph failed";
|
||||||
return false;
|
return FAILED;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -327,31 +327,31 @@ string MSInferSession::AjustTargetName(const std::string &device) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool MSInferSession::InitEnv(const std::string &device, uint32_t device_id) {
|
Status MSInferSession::InitEnv(const std::string &device, uint32_t device_id) {
|
||||||
RegAllOp();
|
RegAllOp();
|
||||||
auto ms_context = MsContext::GetInstance();
|
auto ms_context = MsContext::GetInstance();
|
||||||
ms_context->set_execution_mode(kGraphMode);
|
ms_context->set_execution_mode(kGraphMode);
|
||||||
ms_context->set_device_id(device_id);
|
ms_context->set_device_id(device_id);
|
||||||
auto ajust_device = AjustTargetName(device);
|
auto ajust_device = AjustTargetName(device);
|
||||||
if (ajust_device == "") {
|
if (ajust_device == "") {
|
||||||
return false;
|
return FAILED;
|
||||||
}
|
}
|
||||||
ms_context->set_device_target(device);
|
ms_context->set_device_target(device);
|
||||||
session_impl_ = session::SessionFactory::Get().Create(ajust_device);
|
session_impl_ = session::SessionFactory::Get().Create(ajust_device);
|
||||||
if (session_impl_ == nullptr) {
|
if (session_impl_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available.";
|
MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available.";
|
||||||
return false;
|
return FAILED;
|
||||||
}
|
}
|
||||||
session_impl_->Init(device_id);
|
session_impl_->Init(device_id);
|
||||||
if (ms_context == nullptr) {
|
if (ms_context == nullptr) {
|
||||||
MS_LOG(ERROR) << "Get Context failed!";
|
MS_LOG(ERROR) << "Get Context failed!";
|
||||||
return false;
|
return FAILED;
|
||||||
}
|
}
|
||||||
if (!ms_context->OpenTsd()) {
|
if (!ms_context->OpenTsd()) {
|
||||||
MS_LOG(ERROR) << "Session init OpenTsd failed!";
|
MS_LOG(ERROR) << "Session init OpenTsd failed!";
|
||||||
return false;
|
return FAILED;
|
||||||
}
|
}
|
||||||
return true;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool MSInferSession::CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const {
|
bool MSInferSession::CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const {
|
||||||
|
|
|
@ -38,11 +38,11 @@ class MSInferSession : public InferSession {
|
||||||
MSInferSession();
|
MSInferSession();
|
||||||
~MSInferSession();
|
~MSInferSession();
|
||||||
|
|
||||||
bool InitEnv(const std::string &device_type, uint32_t device_id) override;
|
Status InitEnv(const std::string &device_type, uint32_t device_id) override;
|
||||||
bool FinalizeEnv() override;
|
Status FinalizeEnv() override;
|
||||||
bool LoadModelFromFile(const std::string &file_name, uint32_t &model_id) override;
|
Status LoadModelFromFile(const std::string &file_name, uint32_t &model_id) override;
|
||||||
bool UnloadModel(uint32_t model_id) override;
|
Status UnloadModel(uint32_t model_id) override;
|
||||||
bool ExecuteModel(uint32_t model_id, const RequestBase &inputs, ReplyBase &outputs) override;
|
Status ExecuteModel(uint32_t model_id, const RequestBase &inputs, ReplyBase &outputs) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<session::SessionBasic> session_impl_ = nullptr;
|
std::shared_ptr<session::SessionBasic> session_impl_ = nullptr;
|
||||||
|
@ -57,7 +57,7 @@ class MSInferSession : public InferSession {
|
||||||
std::shared_ptr<std::vector<char>> ReadFile(const std::string &file);
|
std::shared_ptr<std::vector<char>> ReadFile(const std::string &file);
|
||||||
static void RegAllOp();
|
static void RegAllOp();
|
||||||
string AjustTargetName(const std::string &device);
|
string AjustTargetName(const std::string &device);
|
||||||
bool CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id);
|
Status CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id);
|
||||||
bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const;
|
bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const;
|
||||||
std::vector<tensor::TensorPtr> RunGraph(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs);
|
std::vector<tensor::TensorPtr> RunGraph(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs);
|
||||||
};
|
};
|
||||||
|
|
|
@ -35,53 +35,53 @@ std::shared_ptr<InferSession> InferSession::CreateSession(const std::string &dev
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AclSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) {
|
Status AclSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) {
|
||||||
return model_process_.LoadModelFromFile(file_name, model_id);
|
return model_process_.LoadModelFromFile(file_name, model_id) ? SUCCESS : FAILED;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AclSession::UnloadModel(uint32_t model_id) {
|
Status AclSession::UnloadModel(uint32_t model_id) {
|
||||||
model_process_.UnLoad();
|
model_process_.UnLoad();
|
||||||
return true;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AclSession::ExecuteModel(uint32_t model_id, const RequestBase &request,
|
Status AclSession::ExecuteModel(uint32_t model_id, const RequestBase &request,
|
||||||
ReplyBase &reply) { // set d context
|
ReplyBase &reply) { // set d context
|
||||||
aclError rt_ret = aclrtSetCurrentContext(context_);
|
aclError rt_ret = aclrtSetCurrentContext(context_);
|
||||||
if (rt_ret != ACL_ERROR_NONE) {
|
if (rt_ret != ACL_ERROR_NONE) {
|
||||||
MSI_LOG_ERROR << "set the ascend device context failed";
|
MSI_LOG_ERROR << "set the ascend device context failed";
|
||||||
return false;
|
return FAILED;
|
||||||
}
|
}
|
||||||
return model_process_.Execute(request, reply);
|
return model_process_.Execute(request, reply) ? SUCCESS : FAILED;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AclSession::InitEnv(const std::string &device_type, uint32_t device_id) {
|
Status AclSession::InitEnv(const std::string &device_type, uint32_t device_id) {
|
||||||
device_type_ = device_type;
|
device_type_ = device_type;
|
||||||
device_id_ = device_id;
|
device_id_ = device_id;
|
||||||
auto ret = aclInit(nullptr);
|
auto ret = aclInit(nullptr);
|
||||||
if (ret != ACL_ERROR_NONE) {
|
if (ret != ACL_ERROR_NONE) {
|
||||||
MSI_LOG_ERROR << "Execute aclInit Failed";
|
MSI_LOG_ERROR << "Execute aclInit Failed";
|
||||||
return false;
|
return FAILED;
|
||||||
}
|
}
|
||||||
MSI_LOG_INFO << "acl init success";
|
MSI_LOG_INFO << "acl init success";
|
||||||
|
|
||||||
ret = aclrtSetDevice(device_id_);
|
ret = aclrtSetDevice(device_id_);
|
||||||
if (ret != ACL_ERROR_NONE) {
|
if (ret != ACL_ERROR_NONE) {
|
||||||
MSI_LOG_ERROR << "acl open device " << device_id_ << " failed";
|
MSI_LOG_ERROR << "acl open device " << device_id_ << " failed";
|
||||||
return false;
|
return FAILED;
|
||||||
}
|
}
|
||||||
MSI_LOG_INFO << "open device " << device_id_ << " success";
|
MSI_LOG_INFO << "open device " << device_id_ << " success";
|
||||||
|
|
||||||
ret = aclrtCreateContext(&context_, device_id_);
|
ret = aclrtCreateContext(&context_, device_id_);
|
||||||
if (ret != ACL_ERROR_NONE) {
|
if (ret != ACL_ERROR_NONE) {
|
||||||
MSI_LOG_ERROR << "acl create context failed";
|
MSI_LOG_ERROR << "acl create context failed";
|
||||||
return false;
|
return FAILED;
|
||||||
}
|
}
|
||||||
MSI_LOG_INFO << "create context success";
|
MSI_LOG_INFO << "create context success";
|
||||||
|
|
||||||
ret = aclrtCreateStream(&stream_);
|
ret = aclrtCreateStream(&stream_);
|
||||||
if (ret != ACL_ERROR_NONE) {
|
if (ret != ACL_ERROR_NONE) {
|
||||||
MSI_LOG_ERROR << "acl create stream failed";
|
MSI_LOG_ERROR << "acl create stream failed";
|
||||||
return false;
|
return FAILED;
|
||||||
}
|
}
|
||||||
MSI_LOG_INFO << "create stream success";
|
MSI_LOG_INFO << "create stream success";
|
||||||
|
|
||||||
|
@ -89,17 +89,17 @@ bool AclSession::InitEnv(const std::string &device_type, uint32_t device_id) {
|
||||||
ret = aclrtGetRunMode(&run_mode);
|
ret = aclrtGetRunMode(&run_mode);
|
||||||
if (ret != ACL_ERROR_NONE) {
|
if (ret != ACL_ERROR_NONE) {
|
||||||
MSI_LOG_ERROR << "acl get run mode failed";
|
MSI_LOG_ERROR << "acl get run mode failed";
|
||||||
return false;
|
return FAILED;
|
||||||
}
|
}
|
||||||
bool is_device = (run_mode == ACL_DEVICE);
|
bool is_device = (run_mode == ACL_DEVICE);
|
||||||
model_process_.SetIsDevice(is_device);
|
model_process_.SetIsDevice(is_device);
|
||||||
MSI_LOG_INFO << "get run mode success is device input/output " << is_device;
|
MSI_LOG_INFO << "get run mode success is device input/output " << is_device;
|
||||||
|
|
||||||
MSI_LOG_INFO << "Init acl success, device id " << device_id_;
|
MSI_LOG_INFO << "Init acl success, device id " << device_id_;
|
||||||
return true;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AclSession::FinalizeEnv() {
|
Status AclSession::FinalizeEnv() {
|
||||||
aclError ret;
|
aclError ret;
|
||||||
if (stream_ != nullptr) {
|
if (stream_ != nullptr) {
|
||||||
ret = aclrtDestroyStream(stream_);
|
ret = aclrtDestroyStream(stream_);
|
||||||
|
@ -129,7 +129,7 @@ bool AclSession::FinalizeEnv() {
|
||||||
MSI_LOG_ERROR << "finalize acl failed";
|
MSI_LOG_ERROR << "finalize acl failed";
|
||||||
}
|
}
|
||||||
MSI_LOG_INFO << "end to finalize acl";
|
MSI_LOG_INFO << "end to finalize acl";
|
||||||
return true;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
AclSession::AclSession() = default;
|
AclSession::AclSession() = default;
|
||||||
|
|
|
@ -32,11 +32,11 @@ class AclSession : public InferSession {
|
||||||
public:
|
public:
|
||||||
AclSession();
|
AclSession();
|
||||||
|
|
||||||
bool InitEnv(const std::string &device_type, uint32_t device_id) override;
|
Status InitEnv(const std::string &device_type, uint32_t device_id) override;
|
||||||
bool FinalizeEnv() override;
|
Status FinalizeEnv() override;
|
||||||
bool LoadModelFromFile(const std::string &file_name, uint32_t &model_id) override;
|
Status LoadModelFromFile(const std::string &file_name, uint32_t &model_id) override;
|
||||||
bool UnloadModel(uint32_t model_id) override;
|
Status UnloadModel(uint32_t model_id) override;
|
||||||
bool ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) override;
|
Status ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::string device_type_;
|
std::string device_type_;
|
||||||
|
|
|
@ -31,6 +31,7 @@
|
||||||
#include "core/version_control/version_controller.h"
|
#include "core/version_control/version_controller.h"
|
||||||
#include "core/util/file_system_operation.h"
|
#include "core/util/file_system_operation.h"
|
||||||
#include "core/serving_tensor.h"
|
#include "core/serving_tensor.h"
|
||||||
|
#include "util/status.h"
|
||||||
|
|
||||||
using ms_serving::MSService;
|
using ms_serving::MSService;
|
||||||
using ms_serving::PredictReply;
|
using ms_serving::PredictReply;
|
||||||
|
@ -79,9 +80,9 @@ Status Session::Predict(const PredictRequest &request, PredictReply &reply) {
|
||||||
|
|
||||||
auto ret = session_->ExecuteModel(graph_id_, serving_request, serving_reply);
|
auto ret = session_->ExecuteModel(graph_id_, serving_request, serving_reply);
|
||||||
MSI_LOG(INFO) << "run Predict finished";
|
MSI_LOG(INFO) << "run Predict finished";
|
||||||
if (!ret) {
|
if (Status(ret) != SUCCESS) {
|
||||||
MSI_LOG(ERROR) << "execute model return failed";
|
MSI_LOG(ERROR) << "execute model return failed";
|
||||||
return FAILED;
|
return Status(ret);
|
||||||
}
|
}
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
@ -97,9 +98,9 @@ Status Session::Warmup(const MindSporeModelPtr model) {
|
||||||
MSI_TIME_STAMP_START(LoadModelFromFile)
|
MSI_TIME_STAMP_START(LoadModelFromFile)
|
||||||
auto ret = session_->LoadModelFromFile(file_name, graph_id_);
|
auto ret = session_->LoadModelFromFile(file_name, graph_id_);
|
||||||
MSI_TIME_STAMP_END(LoadModelFromFile)
|
MSI_TIME_STAMP_END(LoadModelFromFile)
|
||||||
if (!ret) {
|
if (Status(ret) != SUCCESS) {
|
||||||
MSI_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
|
MSI_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
|
||||||
return FAILED;
|
return Status(ret);
|
||||||
}
|
}
|
||||||
model_loaded_ = true;
|
model_loaded_ = true;
|
||||||
MSI_LOG(INFO) << "Session Warmup finished";
|
MSI_LOG(INFO) << "Session Warmup finished";
|
||||||
|
@ -119,12 +120,22 @@ namespace {
|
||||||
static const uint32_t uint32max = 0x7FFFFFFF;
|
static const uint32_t uint32max = 0x7FFFFFFF;
|
||||||
std::promise<void> exit_requested;
|
std::promise<void> exit_requested;
|
||||||
|
|
||||||
void ClearEnv() {
|
void ClearEnv() { Session::Instance().Clear(); }
|
||||||
Session::Instance().Clear();
|
|
||||||
// inference::ExitInference();
|
|
||||||
}
|
|
||||||
void HandleSignal(int sig) { exit_requested.set_value(); }
|
void HandleSignal(int sig) { exit_requested.set_value(); }
|
||||||
|
|
||||||
|
grpc::Status CreatGRPCStatus(Status status) {
|
||||||
|
switch (status) {
|
||||||
|
case SUCCESS:
|
||||||
|
return grpc::Status::OK;
|
||||||
|
case FAILED:
|
||||||
|
return grpc::Status::CANCELLED;
|
||||||
|
case INVALID_INPUTS:
|
||||||
|
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "The Predict Inputs do not match the Model Request!");
|
||||||
|
default:
|
||||||
|
return grpc::Status::CANCELLED;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Service Implement
|
// Service Implement
|
||||||
|
@ -134,8 +145,8 @@ class MSServiceImpl final : public MSService::Service {
|
||||||
MSI_TIME_STAMP_START(Predict)
|
MSI_TIME_STAMP_START(Predict)
|
||||||
auto res = Session::Instance().Predict(*request, *reply);
|
auto res = Session::Instance().Predict(*request, *reply);
|
||||||
MSI_TIME_STAMP_END(Predict)
|
MSI_TIME_STAMP_END(Predict)
|
||||||
if (res != SUCCESS) {
|
if (res != inference::SUCCESS) {
|
||||||
return grpc::Status::CANCELLED;
|
return CreatGRPCStatus(res);
|
||||||
}
|
}
|
||||||
MSI_LOG(INFO) << "Finish call service Eval";
|
MSI_LOG(INFO) << "Finish call service Eval";
|
||||||
return grpc::Status::OK;
|
return grpc::Status::OK;
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace serving {
|
namespace serving {
|
||||||
using Status = uint32_t;
|
using Status = uint32_t;
|
||||||
enum ServingStatus { SUCCESS = 0, FAILED };
|
enum ServingStatus { SUCCESS = 0, FAILED, INVALID_INPUTS };
|
||||||
} // namespace serving
|
} // namespace serving
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -31,51 +31,51 @@ using ms_serving::TensorShape;
|
||||||
|
|
||||||
class MSClient {
|
class MSClient {
|
||||||
public:
|
public:
|
||||||
explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {}
|
explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {}
|
||||||
|
|
||||||
~MSClient() = default;
|
~MSClient() = default;
|
||||||
|
|
||||||
std::string Predict() {
|
std::string Predict() {
|
||||||
// Data we are sending to the server.
|
// Data we are sending to the server.
|
||||||
PredictRequest request;
|
PredictRequest request;
|
||||||
|
|
||||||
Tensor data;
|
Tensor data;
|
||||||
TensorShape shape;
|
TensorShape shape;
|
||||||
shape.add_dims(4);
|
shape.add_dims(4);
|
||||||
*data.mutable_tensor_shape() = shape;
|
*data.mutable_tensor_shape() = shape;
|
||||||
data.set_tensor_type(ms_serving::MS_FLOAT32);
|
data.set_tensor_type(ms_serving::MS_FLOAT32);
|
||||||
std::vector<float> input_data{1, 2, 3, 4};
|
std::vector<float> input_data{1, 2, 3, 4};
|
||||||
data.set_data(input_data.data(), input_data.size() * sizeof(float));
|
data.set_data(input_data.data(), input_data.size() * sizeof(float));
|
||||||
*request.add_data() = data;
|
*request.add_data() = data;
|
||||||
*request.add_data() = data;
|
*request.add_data() = data;
|
||||||
std::cout << "intput tensor size is " << request.data_size() << std::endl;
|
std::cout << "intput tensor size is " << request.data_size() << std::endl;
|
||||||
// Container for the data we expect from the server.
|
// Container for the data we expect from the server.
|
||||||
PredictReply reply;
|
PredictReply reply;
|
||||||
|
|
||||||
// Context for the client. It could be used to convey extra information to
|
// Context for the client. It could be used to convey extra information to
|
||||||
// the server and/or tweak certain RPC behaviors.
|
// the server and/or tweak certain RPC behaviors.
|
||||||
ClientContext context;
|
ClientContext context;
|
||||||
|
|
||||||
// The actual RPC.
|
// The actual RPC.
|
||||||
Status status = stub_->Predict(&context, request, &reply);
|
Status status = stub_->Predict(&context, request, &reply);
|
||||||
std::cout << "Compute [1, 2, 3, 4] + [1, 2, 3, 4]" << std::endl;
|
std::cout << "Compute [1, 2, 3, 4] + [1, 2, 3, 4]" << std::endl;
|
||||||
|
|
||||||
|
// Act upon its status.
|
||||||
|
if (status.ok()) {
|
||||||
std::cout << "Add result is";
|
std::cout << "Add result is";
|
||||||
for (size_t i = 0; i < reply.result(0).data().size() / sizeof(float); i++) {
|
for (size_t i = 0; i < reply.result(0).data().size() / sizeof(float); i++) {
|
||||||
std::cout << " " << (reinterpret_cast<const float *>(reply.mutable_result(0)->mutable_data()->data()))[i];
|
std::cout << " " << (reinterpret_cast<const float *>(reply.mutable_result(0)->mutable_data()->data()))[i];
|
||||||
}
|
}
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
|
return "RPC OK";
|
||||||
// Act upon its status.
|
} else {
|
||||||
if (status.ok()) {
|
std::cout << status.error_code() << ": " << status.error_message() << std::endl;
|
||||||
return "RPC OK";
|
return "RPC failed";
|
||||||
} else {
|
|
||||||
std::cout << status.error_code() << ": " << status.error_message() << std::endl;
|
|
||||||
return "RPC failed";
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<MSService::Stub> stub_;
|
std::unique_ptr<MSService::Stub> stub_;
|
||||||
};
|
};
|
||||||
|
|
||||||
int main(int argc, char **argv) {
|
int main(int argc, char **argv) {
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
import sys
|
||||||
import grpc
|
import grpc
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import ms_service_pb2
|
import ms_service_pb2
|
||||||
|
@ -19,7 +20,19 @@ import ms_service_pb2_grpc
|
||||||
|
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
channel = grpc.insecure_channel('localhost:5500')
|
if len(sys.argv) > 2:
|
||||||
|
sys.exit("input error")
|
||||||
|
channel_str = ""
|
||||||
|
if len(sys.argv) == 2:
|
||||||
|
split_args = sys.argv[1].split('=')
|
||||||
|
if len(split_args) > 1:
|
||||||
|
channel_str = split_args[1]
|
||||||
|
else:
|
||||||
|
channel_str = 'localhost:5500'
|
||||||
|
else:
|
||||||
|
channel_str = 'localhost:5500'
|
||||||
|
|
||||||
|
channel = grpc.insecure_channel(channel_str)
|
||||||
stub = ms_service_pb2_grpc.MSServiceStub(channel)
|
stub = ms_service_pb2_grpc.MSServiceStub(channel)
|
||||||
request = ms_service_pb2.PredictRequest()
|
request = ms_service_pb2.PredictRequest()
|
||||||
|
|
||||||
|
@ -33,11 +46,17 @@ def run():
|
||||||
y.tensor_type = ms_service_pb2.MS_FLOAT32
|
y.tensor_type = ms_service_pb2.MS_FLOAT32
|
||||||
y.data = (np.ones([4]).astype(np.float32)).tobytes()
|
y.data = (np.ones([4]).astype(np.float32)).tobytes()
|
||||||
|
|
||||||
result = stub.Predict(request)
|
try:
|
||||||
print(result)
|
result = stub.Predict(request)
|
||||||
result_np = np.frombuffer(result.result[0].data, dtype=np.float32).reshape(result.result[0].tensor_shape.dims)
|
print(result)
|
||||||
print("ms client received: ")
|
result_np = np.frombuffer(result.result[0].data, dtype=np.float32).reshape(result.result[0].tensor_shape.dims)
|
||||||
print(result_np)
|
print("ms client received: ")
|
||||||
|
print(result_np)
|
||||||
|
except grpc.RpcError as e:
|
||||||
|
print(e.details())
|
||||||
|
status_code = e.code()
|
||||||
|
print(status_code.name)
|
||||||
|
print(status_code.value)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
run()
|
run()
|
||||||
|
|
Loading…
Reference in New Issue