!3794 modify inference return type

Merge pull request !3794 from dinghao/master
mindspore-ci-bot 2020-08-03 09:06:06 +08:00 committed by Gitee
commit 2b56562770
11 changed files with 149 additions and 118 deletions

.gitignore

@@ -41,6 +41,7 @@ cmake-build-debug
 *.pb.h
 *.pb.cc
 *.pb
+*_grpc.py
 # Object files
 *.o


@@ -24,20 +24,20 @@
 namespace mindspore {
 namespace inference {
+enum Status { SUCCESS = 0, FAILED, INVALID_INPUTS };
 class MS_API InferSession {
  public:
   InferSession() = default;
   virtual ~InferSession() = default;
-  virtual bool InitEnv(const std::string &device_type, uint32_t device_id) = 0;
+  virtual Status InitEnv(const std::string &device_type, uint32_t device_id) = 0;
-  virtual bool FinalizeEnv() = 0;
+  virtual Status FinalizeEnv() = 0;
-  virtual bool LoadModelFromFile(const std::string &file_name, uint32_t &model_id) = 0;
+  virtual Status LoadModelFromFile(const std::string &file_name, uint32_t &model_id) = 0;
-  virtual bool UnloadModel(uint32_t model_id) = 0;
+  virtual Status UnloadModel(uint32_t model_id) = 0;
   // override this method to avoid request/reply data copy
-  virtual bool ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) = 0;
+  virtual Status ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) = 0;
-  virtual bool ExecuteModel(uint32_t model_id, const std::vector<InferTensor> &inputs,
+  virtual Status ExecuteModel(uint32_t model_id, const std::vector<InferTensor> &inputs,
                               std::vector<InferTensor> &outputs) {
     VectorInferTensorWrapRequest request(inputs);
     VectorInferTensorWrapReply reply(outputs);
     return ExecuteModel(model_id, request, reply);
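
For reference, a minimal caller sketch against the Status-based API above (not part of the diff). The include path, the device string "Ascend" and the model file name are placeholders, and no real input tensors are filled in; it only shows that callers now compare against the Status enum instead of a bool.

// Illustrative only: checking the new Status return values.
#include <iostream>
#include <memory>
#include <vector>
#include "include/inference.h"  // assumed header location for InferSession / InferTensor / Status

int main() {
  auto session = mindspore::inference::InferSession::CreateSession("Ascend", 0);
  if (session == nullptr) {
    std::cerr << "create session failed" << std::endl;
    return 1;
  }
  uint32_t model_id = 0;
  // The return value is now a Status enum, not a bool.
  if (session->LoadModelFromFile("model_file.pb", model_id) != mindspore::inference::SUCCESS) {
    std::cerr << "load model failed" << std::endl;
    return 1;
  }
  std::vector<mindspore::inference::InferTensor> inputs, outputs;  // fill inputs before a real call
  auto ret = session->ExecuteModel(model_id, inputs, outputs);
  if (ret == mindspore::inference::INVALID_INPUTS) {
    std::cerr << "inputs do not match the model" << std::endl;
  } else if (ret != mindspore::inference::SUCCESS) {
    std::cerr << "execute model failed" << std::endl;
  }
  session->UnloadModel(model_id);
  session->FinalizeEnv();
  return 0;
}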


@@ -37,8 +37,8 @@ namespace mindspore::inference {
 std::shared_ptr<InferSession> InferSession::CreateSession(const std::string &device, uint32_t device_id) {
   try {
     auto session = std::make_shared<MSInferSession>();
-    bool ret = session->InitEnv(device, device_id);
+    Status ret = session->InitEnv(device, device_id);
-    if (!ret) {
+    if (ret != SUCCESS) {
       return nullptr;
     }
     return session;
@@ -84,21 +84,21 @@ std::shared_ptr<std::vector<char>> MSInferSession::ReadFile(const std::string &f
   return buf;
 }
-bool MSInferSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) {
+Status MSInferSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) {
   auto graphBuf = ReadFile(file_name);
   if (graphBuf == nullptr) {
     MS_LOG(ERROR) << "Read model file failed, file name is " << file_name.c_str();
-    return false;
+    return FAILED;
   }
   auto graph = LoadModel(graphBuf->data(), graphBuf->size(), device_type_);
   if (graph == nullptr) {
     MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
-    return false;
+    return FAILED;
   }
-  bool ret = CompileGraph(graph, model_id);
+  Status ret = CompileGraph(graph, model_id);
-  if (!ret) {
+  if (ret != SUCCESS) {
     MS_LOG(ERROR) << "Compile graph model failed, file name is " << file_name.c_str();
-    return false;
+    return FAILED;
   }
   MS_LOG(INFO) << "Load model from file " << file_name << " success";
@@ -107,14 +107,14 @@ bool MSInferSession::LoadModelFromFile(const std::string &file_name, uint32_t &m
   rtError_t rt_ret = rtCtxGetCurrent(&context_);
   if (rt_ret != RT_ERROR_NONE || context_ == nullptr) {
     MS_LOG(ERROR) << "the ascend device context is null";
-    return false;
+    return FAILED;
   }
 #endif
-  return true;
+  return SUCCESS;
 }
-bool MSInferSession::UnloadModel(uint32_t model_id) { return true; }
+Status MSInferSession::UnloadModel(uint32_t model_id) { return SUCCESS; }
 tensor::TensorPtr ServingTensor2MSTensor(const InferTensorBase &out_tensor) {
   std::vector<int> shape;
@@ -170,16 +170,16 @@ void MSTensor2ServingTensor(tensor::TensorPtr ms_tensor, InferTensorBase &out_te
   out_tensor.set_data(ms_tensor->data_c(), ms_tensor->Size());
 }
-bool MSInferSession::ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) {
+Status MSInferSession::ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) {
 #ifdef ENABLE_D
   if (context_ == nullptr) {
     MS_LOG(ERROR) << "rtCtx is nullptr";
-    return false;
+    return FAILED;
   }
   rtError_t rt_ret = rtCtxSetCurrent(context_);
   if (rt_ret != RT_ERROR_NONE) {
     MS_LOG(ERROR) << "set Ascend rtCtx failed";
-    return false;
+    return FAILED;
   }
 #endif
@@ -187,47 +187,47 @@ bool MSInferSession::ExecuteModel(uint32_t model_id, const RequestBase &request,
   for (size_t i = 0; i < request.size(); i++) {
     if (request[i] == nullptr) {
       MS_LOG(ERROR) << "Execute Model " << model_id << " Failed input tensor is null, index " << i;
-      return false;
+      return FAILED;
     }
     auto input = ServingTensor2MSTensor(*request[i]);
     if (input == nullptr) {
       MS_LOG(ERROR) << "Tensor convert failed";
-      return false;
+      return FAILED;
     }
     inputs.push_back(input);
   }
   if (!CheckModelInputs(model_id, inputs)) {
     MS_LOG(ERROR) << "Check Model " << model_id << " Inputs Failed";
-    return false;
+    return INVALID_INPUTS;
   }
   vector<tensor::TensorPtr> outputs = RunGraph(model_id, inputs);
   if (outputs.empty()) {
     MS_LOG(ERROR) << "Execute Model " << model_id << " Failed";
-    return false;
+    return FAILED;
   }
   reply.clear();
   for (const auto &tensor : outputs) {
     auto out_tensor = reply.add();
     if (out_tensor == nullptr) {
       MS_LOG(ERROR) << "Execute Model " << model_id << " Failed add output tensor failed";
-      return false;
+      return FAILED;
     }
     MSTensor2ServingTensor(tensor, *out_tensor);
   }
-  return true;
+  return SUCCESS;
 }
-bool MSInferSession::FinalizeEnv() {
+Status MSInferSession::FinalizeEnv() {
   auto ms_context = MsContext::GetInstance();
   if (ms_context == nullptr) {
     MS_LOG(ERROR) << "Get Context failed!";
-    return false;
+    return FAILED;
   }
   if (!ms_context->CloseTsd()) {
     MS_LOG(ERROR) << "Inference CloseTsd failed!";
-    return false;
+    return FAILED;
   }
-  return true;
+  return SUCCESS;
 }
 std::shared_ptr<FuncGraph> MSInferSession::LoadModel(const char *model_buf, size_t size, const std::string &device) {
@@ -292,16 +292,16 @@ void MSInferSession::RegAllOp() {
   return;
 }
-bool MSInferSession::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id) {
+Status MSInferSession::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id) {
   MS_ASSERT(session_impl_ != nullptr);
   try {
     auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
     py::gil_scoped_release gil_release;
     model_id = graph_id;
-    return true;
+    return SUCCESS;
   } catch (std::exception &e) {
     MS_LOG(ERROR) << "Inference CompileGraph failed";
-    return false;
+    return FAILED;
   }
 }
@@ -327,31 +327,31 @@ string MSInferSession::AjustTargetName(const std::string &device) {
   }
 }
-bool MSInferSession::InitEnv(const std::string &device, uint32_t device_id) {
+Status MSInferSession::InitEnv(const std::string &device, uint32_t device_id) {
   RegAllOp();
   auto ms_context = MsContext::GetInstance();
   ms_context->set_execution_mode(kGraphMode);
   ms_context->set_device_id(device_id);
   auto ajust_device = AjustTargetName(device);
   if (ajust_device == "") {
-    return false;
+    return FAILED;
   }
   ms_context->set_device_target(device);
   session_impl_ = session::SessionFactory::Get().Create(ajust_device);
   if (session_impl_ == nullptr) {
     MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available.";
-    return false;
+    return FAILED;
   }
   session_impl_->Init(device_id);
   if (ms_context == nullptr) {
     MS_LOG(ERROR) << "Get Context failed!";
-    return false;
+    return FAILED;
   }
   if (!ms_context->OpenTsd()) {
     MS_LOG(ERROR) << "Session init OpenTsd failed!";
-    return false;
+    return FAILED;
   }
-  return true;
+  return SUCCESS;
 }
 bool MSInferSession::CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const {


@@ -38,11 +38,11 @@ class MSInferSession : public InferSession {
   MSInferSession();
   ~MSInferSession();
-  bool InitEnv(const std::string &device_type, uint32_t device_id) override;
+  Status InitEnv(const std::string &device_type, uint32_t device_id) override;
-  bool FinalizeEnv() override;
+  Status FinalizeEnv() override;
-  bool LoadModelFromFile(const std::string &file_name, uint32_t &model_id) override;
+  Status LoadModelFromFile(const std::string &file_name, uint32_t &model_id) override;
-  bool UnloadModel(uint32_t model_id) override;
+  Status UnloadModel(uint32_t model_id) override;
-  bool ExecuteModel(uint32_t model_id, const RequestBase &inputs, ReplyBase &outputs) override;
+  Status ExecuteModel(uint32_t model_id, const RequestBase &inputs, ReplyBase &outputs) override;
  private:
   std::shared_ptr<session::SessionBasic> session_impl_ = nullptr;
@@ -57,7 +57,7 @@ class MSInferSession : public InferSession {
   std::shared_ptr<std::vector<char>> ReadFile(const std::string &file);
   static void RegAllOp();
   string AjustTargetName(const std::string &device);
-  bool CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id);
+  Status CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id);
   bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const;
   std::vector<tensor::TensorPtr> RunGraph(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs);
 };


@@ -35,53 +35,53 @@ std::shared_ptr<InferSession> InferSession::CreateSession(const std::string &dev
   }
 }
-bool AclSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) {
-  return model_process_.LoadModelFromFile(file_name, model_id);
+Status AclSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) {
+  return model_process_.LoadModelFromFile(file_name, model_id) ? SUCCESS : FAILED;
 }
-bool AclSession::UnloadModel(uint32_t model_id) {
+Status AclSession::UnloadModel(uint32_t model_id) {
   model_process_.UnLoad();
-  return true;
+  return SUCCESS;
 }
-bool AclSession::ExecuteModel(uint32_t model_id, const RequestBase &request,
+Status AclSession::ExecuteModel(uint32_t model_id, const RequestBase &request,
                                 ReplyBase &reply) {  // set d context
   aclError rt_ret = aclrtSetCurrentContext(context_);
   if (rt_ret != ACL_ERROR_NONE) {
     MSI_LOG_ERROR << "set the ascend device context failed";
-    return false;
+    return FAILED;
   }
-  return model_process_.Execute(request, reply);
+  return model_process_.Execute(request, reply) ? SUCCESS : FAILED;
 }
-bool AclSession::InitEnv(const std::string &device_type, uint32_t device_id) {
+Status AclSession::InitEnv(const std::string &device_type, uint32_t device_id) {
   device_type_ = device_type;
   device_id_ = device_id;
   auto ret = aclInit(nullptr);
   if (ret != ACL_ERROR_NONE) {
     MSI_LOG_ERROR << "Execute aclInit Failed";
-    return false;
+    return FAILED;
   }
   MSI_LOG_INFO << "acl init success";
   ret = aclrtSetDevice(device_id_);
   if (ret != ACL_ERROR_NONE) {
     MSI_LOG_ERROR << "acl open device " << device_id_ << " failed";
-    return false;
+    return FAILED;
   }
   MSI_LOG_INFO << "open device " << device_id_ << " success";
   ret = aclrtCreateContext(&context_, device_id_);
   if (ret != ACL_ERROR_NONE) {
     MSI_LOG_ERROR << "acl create context failed";
-    return false;
+    return FAILED;
   }
   MSI_LOG_INFO << "create context success";
   ret = aclrtCreateStream(&stream_);
   if (ret != ACL_ERROR_NONE) {
     MSI_LOG_ERROR << "acl create stream failed";
-    return false;
+    return FAILED;
   }
   MSI_LOG_INFO << "create stream success";
@@ -89,17 +89,17 @@ bool AclSession::InitEnv(const std::string &device_type, uint32_t device_id) {
   ret = aclrtGetRunMode(&run_mode);
   if (ret != ACL_ERROR_NONE) {
     MSI_LOG_ERROR << "acl get run mode failed";
-    return false;
+    return FAILED;
   }
   bool is_device = (run_mode == ACL_DEVICE);
   model_process_.SetIsDevice(is_device);
   MSI_LOG_INFO << "get run mode success is device input/output " << is_device;
   MSI_LOG_INFO << "Init acl success, device id " << device_id_;
-  return true;
+  return SUCCESS;
 }
-bool AclSession::FinalizeEnv() {
+Status AclSession::FinalizeEnv() {
   aclError ret;
   if (stream_ != nullptr) {
     ret = aclrtDestroyStream(stream_);
@@ -129,7 +129,7 @@ bool AclSession::FinalizeEnv() {
     MSI_LOG_ERROR << "finalize acl failed";
   }
   MSI_LOG_INFO << "end to finalize acl";
-  return true;
+  return SUCCESS;
 }
 AclSession::AclSession() = default;


@@ -32,11 +32,11 @@ class AclSession : public InferSession {
  public:
   AclSession();
-  bool InitEnv(const std::string &device_type, uint32_t device_id) override;
+  Status InitEnv(const std::string &device_type, uint32_t device_id) override;
-  bool FinalizeEnv() override;
+  Status FinalizeEnv() override;
-  bool LoadModelFromFile(const std::string &file_name, uint32_t &model_id) override;
+  Status LoadModelFromFile(const std::string &file_name, uint32_t &model_id) override;
-  bool UnloadModel(uint32_t model_id) override;
+  Status UnloadModel(uint32_t model_id) override;
-  bool ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) override;
+  Status ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) override;
  private:
   std::string device_type_;


@@ -31,6 +31,7 @@
 #include "core/version_control/version_controller.h"
 #include "core/util/file_system_operation.h"
 #include "core/serving_tensor.h"
+#include "util/status.h"
 using ms_serving::MSService;
 using ms_serving::PredictReply;
@@ -79,9 +80,9 @@ Status Session::Predict(const PredictRequest &request, PredictReply &reply) {
   auto ret = session_->ExecuteModel(graph_id_, serving_request, serving_reply);
   MSI_LOG(INFO) << "run Predict finished";
-  if (!ret) {
+  if (Status(ret) != SUCCESS) {
     MSI_LOG(ERROR) << "execute model return failed";
-    return FAILED;
+    return Status(ret);
   }
   return SUCCESS;
 }
@@ -97,9 +98,9 @@ Status Session::Warmup(const MindSporeModelPtr model) {
   MSI_TIME_STAMP_START(LoadModelFromFile)
   auto ret = session_->LoadModelFromFile(file_name, graph_id_);
   MSI_TIME_STAMP_END(LoadModelFromFile)
-  if (!ret) {
+  if (Status(ret) != SUCCESS) {
     MSI_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
-    return FAILED;
+    return Status(ret);
   }
   model_loaded_ = true;
   MSI_LOG(INFO) << "Session Warmup finished";
@@ -119,12 +120,22 @@ namespace {
 static const uint32_t uint32max = 0x7FFFFFFF;
 std::promise<void> exit_requested;
-void ClearEnv() {
-  Session::Instance().Clear();
-  // inference::ExitInference();
-}
+void ClearEnv() { Session::Instance().Clear(); }
 void HandleSignal(int sig) { exit_requested.set_value(); }
+grpc::Status CreatGRPCStatus(Status status) {
+  switch (status) {
+    case SUCCESS:
+      return grpc::Status::OK;
+    case FAILED:
+      return grpc::Status::CANCELLED;
+    case INVALID_INPUTS:
+      return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "The Predict Inputs do not match the Model Request!");
+    default:
+      return grpc::Status::CANCELLED;
+  }
+}
 }  // namespace
 // Service Implement
@@ -134,8 +145,8 @@ class MSServiceImpl final : public MSService::Service {
     MSI_TIME_STAMP_START(Predict)
     auto res = Session::Instance().Predict(*request, *reply);
     MSI_TIME_STAMP_END(Predict)
-    if (res != SUCCESS) {
-      return grpc::Status::CANCELLED;
+    if (res != inference::SUCCESS) {
+      return CreatGRPCStatus(res);
     }
     MSI_LOG(INFO) << "Finish call service Eval";
     return grpc::Status::OK;


@@ -18,7 +18,7 @@
 namespace mindspore {
 namespace serving {
 using Status = uint32_t;
-enum ServingStatus { SUCCESS = 0, FAILED };
+enum ServingStatus { SUCCESS = 0, FAILED, INVALID_INPUTS };
 }  // namespace serving
 }  // namespace mindspore
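
Side note (not part of the diff): Session::Predict above converts the inference result with Status(ret) and returns it directly, which only works because the serving Status is a plain uint32_t and both enums keep the same ordering (SUCCESS = 0, FAILED, INVALID_INPUTS). A small illustrative check; the inference header path is an assumption:

// Sketch only: the two status enums must stay value-compatible for Status(ret) to be meaningful.
#include <cstdint>
#include "include/inference.h"  // assumed location of mindspore::inference::Status
#include "util/status.h"        // mindspore::serving Status / ServingStatus

static_assert(static_cast<std::uint32_t>(mindspore::inference::SUCCESS) ==
                static_cast<std::uint32_t>(mindspore::serving::SUCCESS),
              "inference and serving status values must line up");
static_assert(static_cast<std::uint32_t>(mindspore::inference::INVALID_INPUTS) ==
                static_cast<std::uint32_t>(mindspore::serving::INVALID_INPUTS),
              "inference and serving status values must line up");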


@@ -31,51 +31,51 @@ using ms_serving::TensorShape;
 class MSClient {
  public:
   explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {}
   ~MSClient() = default;
   std::string Predict() {
     // Data we are sending to the server.
     PredictRequest request;
     Tensor data;
     TensorShape shape;
     shape.add_dims(4);
     *data.mutable_tensor_shape() = shape;
     data.set_tensor_type(ms_serving::MS_FLOAT32);
     std::vector<float> input_data{1, 2, 3, 4};
     data.set_data(input_data.data(), input_data.size() * sizeof(float));
     *request.add_data() = data;
     *request.add_data() = data;
     std::cout << "intput tensor size is " << request.data_size() << std::endl;
     // Container for the data we expect from the server.
     PredictReply reply;
     // Context for the client. It could be used to convey extra information to
     // the server and/or tweak certain RPC behaviors.
     ClientContext context;
     // The actual RPC.
     Status status = stub_->Predict(&context, request, &reply);
     std::cout << "Compute [1, 2, 3, 4] + [1, 2, 3, 4]" << std::endl;
+    // Act upon its status.
+    if (status.ok()) {
       std::cout << "Add result is";
       for (size_t i = 0; i < reply.result(0).data().size() / sizeof(float); i++) {
         std::cout << " " << (reinterpret_cast<const float *>(reply.mutable_result(0)->mutable_data()->data()))[i];
       }
       std::cout << std::endl;
+      return "RPC OK";
-    // Act upon its status.
-    if (status.ok()) {
-      return "RPC OK";
     } else {
       std::cout << status.error_code() << ": " << status.error_message() << std::endl;
       return "RPC failed";
     }
   }
  private:
   std::unique_ptr<MSService::Stub> stub_;
 };
 int main(int argc, char **argv) {


@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
+import sys
 import grpc
 import numpy as np
 import ms_service_pb2
@@ -19,7 +20,19 @@ import ms_service_pb2_grpc
 def run():
-    channel = grpc.insecure_channel('localhost:5500')
+    if len(sys.argv) > 2:
+        sys.exit("input error")
+    channel_str = ""
+    if len(sys.argv) == 2:
+        split_args = sys.argv[1].split('=')
+        if len(split_args) > 1:
+            channel_str = split_args[1]
+        else:
+            channel_str = 'localhost:5500'
+    else:
+        channel_str = 'localhost:5500'
+    channel = grpc.insecure_channel(channel_str)
     stub = ms_service_pb2_grpc.MSServiceStub(channel)
     request = ms_service_pb2.PredictRequest()
@@ -33,11 +46,17 @@ def run():
     y.tensor_type = ms_service_pb2.MS_FLOAT32
     y.data = (np.ones([4]).astype(np.float32)).tobytes()
-    result = stub.Predict(request)
-    print(result)
-    result_np = np.frombuffer(result.result[0].data, dtype=np.float32).reshape(result.result[0].tensor_shape.dims)
-    print("ms client received: ")
-    print(result_np)
+    try:
+        result = stub.Predict(request)
+        print(result)
+        result_np = np.frombuffer(result.result[0].data, dtype=np.float32).reshape(result.result[0].tensor_shape.dims)
+        print("ms client received: ")
+        print(result_np)
+    except grpc.RpcError as e:
+        print(e.details())
+        status_code = e.code()
+        print(status_code.name)
+        print(status_code.value)
 if __name__ == '__main__':
     run()
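
Usage note (not part of the diff): with the argument handling added above, the client can be pointed at a non-default server by passing a single argument of the form anything=host:port, for example python ms_client.py target=127.0.0.1:5500 (the script name and the "target" prefix are assumptions; only the text after the first '=' is used). With no argument it still falls back to localhost:5500. On an RPC failure the script now prints the gRPC details and status code; for INVALID_INPUTS the server replies with INVALID_ARGUMENT, as mapped in CreatGRPCStatus above.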