forked from mindspore-Ecosystem/mindspore
!3794 modify inference return type
Merge pull request !3794 from dinghao/master
This commit is contained in:
commit
2b56562770
|
@ -41,6 +41,7 @@ cmake-build-debug
|
|||
*.pb.h
|
||||
*.pb.cc
|
||||
*.pb
|
||||
*_grpc.py
|
||||
|
||||
# Object files
|
||||
*.o
|
||||
|
|
|
@ -24,20 +24,20 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace inference {
|
||||
|
||||
enum Status { SUCCESS = 0, FAILED, INVALID_INPUTS };
|
||||
class MS_API InferSession {
|
||||
public:
|
||||
InferSession() = default;
|
||||
virtual ~InferSession() = default;
|
||||
virtual bool InitEnv(const std::string &device_type, uint32_t device_id) = 0;
|
||||
virtual bool FinalizeEnv() = 0;
|
||||
virtual bool LoadModelFromFile(const std::string &file_name, uint32_t &model_id) = 0;
|
||||
virtual bool UnloadModel(uint32_t model_id) = 0;
|
||||
virtual Status InitEnv(const std::string &device_type, uint32_t device_id) = 0;
|
||||
virtual Status FinalizeEnv() = 0;
|
||||
virtual Status LoadModelFromFile(const std::string &file_name, uint32_t &model_id) = 0;
|
||||
virtual Status UnloadModel(uint32_t model_id) = 0;
|
||||
// override this method to avoid request/reply data copy
|
||||
virtual bool ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) = 0;
|
||||
virtual Status ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) = 0;
|
||||
|
||||
virtual bool ExecuteModel(uint32_t model_id, const std::vector<InferTensor> &inputs,
|
||||
std::vector<InferTensor> &outputs) {
|
||||
virtual Status ExecuteModel(uint32_t model_id, const std::vector<InferTensor> &inputs,
|
||||
std::vector<InferTensor> &outputs) {
|
||||
VectorInferTensorWrapRequest request(inputs);
|
||||
VectorInferTensorWrapReply reply(outputs);
|
||||
return ExecuteModel(model_id, request, reply);
|
||||
|
|
|
@ -37,8 +37,8 @@ namespace mindspore::inference {
|
|||
std::shared_ptr<InferSession> InferSession::CreateSession(const std::string &device, uint32_t device_id) {
|
||||
try {
|
||||
auto session = std::make_shared<MSInferSession>();
|
||||
bool ret = session->InitEnv(device, device_id);
|
||||
if (!ret) {
|
||||
Status ret = session->InitEnv(device, device_id);
|
||||
if (ret != SUCCESS) {
|
||||
return nullptr;
|
||||
}
|
||||
return session;
|
||||
|
@ -84,21 +84,21 @@ std::shared_ptr<std::vector<char>> MSInferSession::ReadFile(const std::string &f
|
|||
return buf;
|
||||
}
|
||||
|
||||
bool MSInferSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) {
|
||||
Status MSInferSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) {
|
||||
auto graphBuf = ReadFile(file_name);
|
||||
if (graphBuf == nullptr) {
|
||||
MS_LOG(ERROR) << "Read model file failed, file name is " << file_name.c_str();
|
||||
return false;
|
||||
return FAILED;
|
||||
}
|
||||
auto graph = LoadModel(graphBuf->data(), graphBuf->size(), device_type_);
|
||||
if (graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
|
||||
return false;
|
||||
return FAILED;
|
||||
}
|
||||
bool ret = CompileGraph(graph, model_id);
|
||||
if (!ret) {
|
||||
Status ret = CompileGraph(graph, model_id);
|
||||
if (ret != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Compile graph model failed, file name is " << file_name.c_str();
|
||||
return false;
|
||||
return FAILED;
|
||||
}
|
||||
MS_LOG(INFO) << "Load model from file " << file_name << " success";
|
||||
|
||||
|
@ -107,14 +107,14 @@ bool MSInferSession::LoadModelFromFile(const std::string &file_name, uint32_t &m
|
|||
rtError_t rt_ret = rtCtxGetCurrent(&context_);
|
||||
if (rt_ret != RT_ERROR_NONE || context_ == nullptr) {
|
||||
MS_LOG(ERROR) << "the ascend device context is null";
|
||||
return false;
|
||||
return FAILED;
|
||||
}
|
||||
#endif
|
||||
|
||||
return true;
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
bool MSInferSession::UnloadModel(uint32_t model_id) { return true; }
|
||||
Status MSInferSession::UnloadModel(uint32_t model_id) { return SUCCESS; }
|
||||
|
||||
tensor::TensorPtr ServingTensor2MSTensor(const InferTensorBase &out_tensor) {
|
||||
std::vector<int> shape;
|
||||
|
@ -170,16 +170,16 @@ void MSTensor2ServingTensor(tensor::TensorPtr ms_tensor, InferTensorBase &out_te
|
|||
out_tensor.set_data(ms_tensor->data_c(), ms_tensor->Size());
|
||||
}
|
||||
|
||||
bool MSInferSession::ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) {
|
||||
Status MSInferSession::ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) {
|
||||
#ifdef ENABLE_D
|
||||
if (context_ == nullptr) {
|
||||
MS_LOG(ERROR) << "rtCtx is nullptr";
|
||||
return false;
|
||||
return FAILED;
|
||||
}
|
||||
rtError_t rt_ret = rtCtxSetCurrent(context_);
|
||||
if (rt_ret != RT_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "set Ascend rtCtx failed";
|
||||
return false;
|
||||
return FAILED;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
@ -187,47 +187,47 @@ bool MSInferSession::ExecuteModel(uint32_t model_id, const RequestBase &request,
|
|||
for (size_t i = 0; i < request.size(); i++) {
|
||||
if (request[i] == nullptr) {
|
||||
MS_LOG(ERROR) << "Execute Model " << model_id << " Failed, input tensor is null, index " << i;
|
||||
return false;
|
||||
return FAILED;
|
||||
}
|
||||
auto input = ServingTensor2MSTensor(*request[i]);
|
||||
if (input == nullptr) {
|
||||
MS_LOG(ERROR) << "Tensor convert failed";
|
||||
return false;
|
||||
return FAILED;
|
||||
}
|
||||
inputs.push_back(input);
|
||||
}
|
||||
if (!CheckModelInputs(model_id, inputs)) {
|
||||
MS_LOG(ERROR) << "Check Model " << model_id << " Inputs Failed";
|
||||
return false;
|
||||
return INVALID_INPUTS;
|
||||
}
|
||||
vector<tensor::TensorPtr> outputs = RunGraph(model_id, inputs);
|
||||
if (outputs.empty()) {
|
||||
MS_LOG(ERROR) << "Execute Model " << model_id << " Failed";
|
||||
return false;
|
||||
return FAILED;
|
||||
}
|
||||
reply.clear();
|
||||
for (const auto &tensor : outputs) {
|
||||
auto out_tensor = reply.add();
|
||||
if (out_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Execute Model " << model_id << " Failed, add output tensor failed";
|
||||
return false;
|
||||
return FAILED;
|
||||
}
|
||||
MSTensor2ServingTensor(tensor, *out_tensor);
|
||||
}
|
||||
return true;
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
bool MSInferSession::FinalizeEnv() {
|
||||
Status MSInferSession::FinalizeEnv() {
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
if (ms_context == nullptr) {
|
||||
MS_LOG(ERROR) << "Get Context failed!";
|
||||
return false;
|
||||
return FAILED;
|
||||
}
|
||||
if (!ms_context->CloseTsd()) {
|
||||
MS_LOG(ERROR) << "Inference CloseTsd failed!";
|
||||
return false;
|
||||
return FAILED;
|
||||
}
|
||||
return true;
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
std::shared_ptr<FuncGraph> MSInferSession::LoadModel(const char *model_buf, size_t size, const std::string &device) {
|
||||
|
@ -292,16 +292,16 @@ void MSInferSession::RegAllOp() {
|
|||
return;
|
||||
}
|
||||
|
||||
bool MSInferSession::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id) {
|
||||
Status MSInferSession::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id) {
|
||||
MS_ASSERT(session_impl_ != nullptr);
|
||||
try {
|
||||
auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
|
||||
py::gil_scoped_release gil_release;
|
||||
model_id = graph_id;
|
||||
return true;
|
||||
return SUCCESS;
|
||||
} catch (std::exception &e) {
|
||||
MS_LOG(ERROR) << "Inference CompileGraph failed";
|
||||
return false;
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -327,31 +327,31 @@ string MSInferSession::AjustTargetName(const std::string &device) {
|
|||
}
|
||||
}
|
||||
|
||||
bool MSInferSession::InitEnv(const std::string &device, uint32_t device_id) {
|
||||
Status MSInferSession::InitEnv(const std::string &device, uint32_t device_id) {
|
||||
RegAllOp();
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
ms_context->set_execution_mode(kGraphMode);
|
||||
ms_context->set_device_id(device_id);
|
||||
auto ajust_device = AjustTargetName(device);
|
||||
if (ajust_device == "") {
|
||||
return false;
|
||||
return FAILED;
|
||||
}
|
||||
ms_context->set_device_target(device);
|
||||
session_impl_ = session::SessionFactory::Get().Create(ajust_device);
|
||||
if (session_impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available.";
|
||||
return false;
|
||||
return FAILED;
|
||||
}
|
||||
session_impl_->Init(device_id);
|
||||
if (ms_context == nullptr) {
|
||||
MS_LOG(ERROR) << "Get Context failed!";
|
||||
return false;
|
||||
return FAILED;
|
||||
}
|
||||
if (!ms_context->OpenTsd()) {
|
||||
MS_LOG(ERROR) << "Session init OpenTsd failed!";
|
||||
return false;
|
||||
return FAILED;
|
||||
}
|
||||
return true;
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
bool MSInferSession::CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const {
|
||||
|
|
|
@ -38,11 +38,11 @@ class MSInferSession : public InferSession {
|
|||
MSInferSession();
|
||||
~MSInferSession();
|
||||
|
||||
bool InitEnv(const std::string &device_type, uint32_t device_id) override;
|
||||
bool FinalizeEnv() override;
|
||||
bool LoadModelFromFile(const std::string &file_name, uint32_t &model_id) override;
|
||||
bool UnloadModel(uint32_t model_id) override;
|
||||
bool ExecuteModel(uint32_t model_id, const RequestBase &inputs, ReplyBase &outputs) override;
|
||||
Status InitEnv(const std::string &device_type, uint32_t device_id) override;
|
||||
Status FinalizeEnv() override;
|
||||
Status LoadModelFromFile(const std::string &file_name, uint32_t &model_id) override;
|
||||
Status UnloadModel(uint32_t model_id) override;
|
||||
Status ExecuteModel(uint32_t model_id, const RequestBase &inputs, ReplyBase &outputs) override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<session::SessionBasic> session_impl_ = nullptr;
|
||||
|
@ -57,7 +57,7 @@ class MSInferSession : public InferSession {
|
|||
std::shared_ptr<std::vector<char>> ReadFile(const std::string &file);
|
||||
static void RegAllOp();
|
||||
string AjustTargetName(const std::string &device);
|
||||
bool CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id);
|
||||
Status CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id);
|
||||
bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const;
|
||||
std::vector<tensor::TensorPtr> RunGraph(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs);
|
||||
};
|
||||
|
|
|
@ -35,53 +35,53 @@ std::shared_ptr<InferSession> InferSession::CreateSession(const std::string &dev
|
|||
}
|
||||
}
|
||||
|
||||
bool AclSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) {
|
||||
return model_process_.LoadModelFromFile(file_name, model_id);
|
||||
Status AclSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) {
|
||||
return model_process_.LoadModelFromFile(file_name, model_id) ? SUCCESS : FAILED;
|
||||
}
|
||||
|
||||
bool AclSession::UnloadModel(uint32_t model_id) {
|
||||
Status AclSession::UnloadModel(uint32_t model_id) {
|
||||
model_process_.UnLoad();
|
||||
return true;
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
bool AclSession::ExecuteModel(uint32_t model_id, const RequestBase &request,
|
||||
ReplyBase &reply) { // set d context
|
||||
Status AclSession::ExecuteModel(uint32_t model_id, const RequestBase &request,
|
||||
ReplyBase &reply) { // set d context
|
||||
aclError rt_ret = aclrtSetCurrentContext(context_);
|
||||
if (rt_ret != ACL_ERROR_NONE) {
|
||||
MSI_LOG_ERROR << "set the ascend device context failed";
|
||||
return false;
|
||||
return FAILED;
|
||||
}
|
||||
return model_process_.Execute(request, reply);
|
||||
return model_process_.Execute(request, reply) ? SUCCESS : FAILED;
|
||||
}
|
||||
|
||||
bool AclSession::InitEnv(const std::string &device_type, uint32_t device_id) {
|
||||
Status AclSession::InitEnv(const std::string &device_type, uint32_t device_id) {
|
||||
device_type_ = device_type;
|
||||
device_id_ = device_id;
|
||||
auto ret = aclInit(nullptr);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MSI_LOG_ERROR << "Execute aclInit Failed";
|
||||
return false;
|
||||
return FAILED;
|
||||
}
|
||||
MSI_LOG_INFO << "acl init success";
|
||||
|
||||
ret = aclrtSetDevice(device_id_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MSI_LOG_ERROR << "acl open device " << device_id_ << " failed";
|
||||
return false;
|
||||
return FAILED;
|
||||
}
|
||||
MSI_LOG_INFO << "open device " << device_id_ << " success";
|
||||
|
||||
ret = aclrtCreateContext(&context_, device_id_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MSI_LOG_ERROR << "acl create context failed";
|
||||
return false;
|
||||
return FAILED;
|
||||
}
|
||||
MSI_LOG_INFO << "create context success";
|
||||
|
||||
ret = aclrtCreateStream(&stream_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MSI_LOG_ERROR << "acl create stream failed";
|
||||
return false;
|
||||
return FAILED;
|
||||
}
|
||||
MSI_LOG_INFO << "create stream success";
|
||||
|
||||
|
@ -89,17 +89,17 @@ bool AclSession::InitEnv(const std::string &device_type, uint32_t device_id) {
|
|||
ret = aclrtGetRunMode(&run_mode);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MSI_LOG_ERROR << "acl get run mode failed";
|
||||
return false;
|
||||
return FAILED;
|
||||
}
|
||||
bool is_device = (run_mode == ACL_DEVICE);
|
||||
model_process_.SetIsDevice(is_device);
|
||||
MSI_LOG_INFO << "get run mode success is device input/output " << is_device;
|
||||
|
||||
MSI_LOG_INFO << "Init acl success, device id " << device_id_;
|
||||
return true;
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
bool AclSession::FinalizeEnv() {
|
||||
Status AclSession::FinalizeEnv() {
|
||||
aclError ret;
|
||||
if (stream_ != nullptr) {
|
||||
ret = aclrtDestroyStream(stream_);
|
||||
|
@ -129,7 +129,7 @@ bool AclSession::FinalizeEnv() {
|
|||
MSI_LOG_ERROR << "finalize acl failed";
|
||||
}
|
||||
MSI_LOG_INFO << "end to finalize acl";
|
||||
return true;
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
AclSession::AclSession() = default;
|
||||
|
|
|
@ -32,11 +32,11 @@ class AclSession : public InferSession {
|
|||
public:
|
||||
AclSession();
|
||||
|
||||
bool InitEnv(const std::string &device_type, uint32_t device_id) override;
|
||||
bool FinalizeEnv() override;
|
||||
bool LoadModelFromFile(const std::string &file_name, uint32_t &model_id) override;
|
||||
bool UnloadModel(uint32_t model_id) override;
|
||||
bool ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) override;
|
||||
Status InitEnv(const std::string &device_type, uint32_t device_id) override;
|
||||
Status FinalizeEnv() override;
|
||||
Status LoadModelFromFile(const std::string &file_name, uint32_t &model_id) override;
|
||||
Status UnloadModel(uint32_t model_id) override;
|
||||
Status ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) override;
|
||||
|
||||
private:
|
||||
std::string device_type_;
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
#include "core/version_control/version_controller.h"
|
||||
#include "core/util/file_system_operation.h"
|
||||
#include "core/serving_tensor.h"
|
||||
#include "util/status.h"
|
||||
|
||||
using ms_serving::MSService;
|
||||
using ms_serving::PredictReply;
|
||||
|
@ -79,9 +80,9 @@ Status Session::Predict(const PredictRequest &request, PredictReply &reply) {
|
|||
|
||||
auto ret = session_->ExecuteModel(graph_id_, serving_request, serving_reply);
|
||||
MSI_LOG(INFO) << "run Predict finished";
|
||||
if (!ret) {
|
||||
if (Status(ret) != SUCCESS) {
|
||||
MSI_LOG(ERROR) << "execute model return failed";
|
||||
return FAILED;
|
||||
return Status(ret);
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
@ -97,9 +98,9 @@ Status Session::Warmup(const MindSporeModelPtr model) {
|
|||
MSI_TIME_STAMP_START(LoadModelFromFile)
|
||||
auto ret = session_->LoadModelFromFile(file_name, graph_id_);
|
||||
MSI_TIME_STAMP_END(LoadModelFromFile)
|
||||
if (!ret) {
|
||||
if (Status(ret) != SUCCESS) {
|
||||
MSI_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
|
||||
return FAILED;
|
||||
return Status(ret);
|
||||
}
|
||||
model_loaded_ = true;
|
||||
MSI_LOG(INFO) << "Session Warmup finished";
|
||||
|
@ -119,12 +120,22 @@ namespace {
|
|||
static const uint32_t uint32max = 0x7FFFFFFF;
|
||||
std::promise<void> exit_requested;
|
||||
|
||||
void ClearEnv() {
|
||||
Session::Instance().Clear();
|
||||
// inference::ExitInference();
|
||||
}
|
||||
void ClearEnv() { Session::Instance().Clear(); }
|
||||
void HandleSignal(int sig) { exit_requested.set_value(); }
|
||||
|
||||
grpc::Status CreatGRPCStatus(Status status) {
|
||||
switch (status) {
|
||||
case SUCCESS:
|
||||
return grpc::Status::OK;
|
||||
case FAILED:
|
||||
return grpc::Status::CANCELLED;
|
||||
case INVALID_INPUTS:
|
||||
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "The Predict Inputs do not match the Model Request!");
|
||||
default:
|
||||
return grpc::Status::CANCELLED;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Service Implement
|
||||
|
@ -134,8 +145,8 @@ class MSServiceImpl final : public MSService::Service {
|
|||
MSI_TIME_STAMP_START(Predict)
|
||||
auto res = Session::Instance().Predict(*request, *reply);
|
||||
MSI_TIME_STAMP_END(Predict)
|
||||
if (res != SUCCESS) {
|
||||
return grpc::Status::CANCELLED;
|
||||
if (res != inference::SUCCESS) {
|
||||
return CreatGRPCStatus(res);
|
||||
}
|
||||
MSI_LOG(INFO) << "Finish call service Eval";
|
||||
return grpc::Status::OK;
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
namespace mindspore {
|
||||
namespace serving {
|
||||
using Status = uint32_t;
|
||||
enum ServingStatus { SUCCESS = 0, FAILED };
|
||||
enum ServingStatus { SUCCESS = 0, FAILED, INVALID_INPUTS };
|
||||
} // namespace serving
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -31,51 +31,51 @@ using ms_serving::TensorShape;
|
|||
|
||||
class MSClient {
|
||||
public:
|
||||
explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {}
|
||||
explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {}
|
||||
|
||||
~MSClient() = default;
|
||||
~MSClient() = default;
|
||||
|
||||
std::string Predict() {
|
||||
// Data we are sending to the server.
|
||||
PredictRequest request;
|
||||
std::string Predict() {
|
||||
// Data we are sending to the server.
|
||||
PredictRequest request;
|
||||
|
||||
Tensor data;
|
||||
TensorShape shape;
|
||||
shape.add_dims(4);
|
||||
*data.mutable_tensor_shape() = shape;
|
||||
data.set_tensor_type(ms_serving::MS_FLOAT32);
|
||||
std::vector<float> input_data{1, 2, 3, 4};
|
||||
data.set_data(input_data.data(), input_data.size() * sizeof(float));
|
||||
*request.add_data() = data;
|
||||
*request.add_data() = data;
|
||||
std::cout << "intput tensor size is " << request.data_size() << std::endl;
|
||||
// Container for the data we expect from the server.
|
||||
PredictReply reply;
|
||||
Tensor data;
|
||||
TensorShape shape;
|
||||
shape.add_dims(4);
|
||||
*data.mutable_tensor_shape() = shape;
|
||||
data.set_tensor_type(ms_serving::MS_FLOAT32);
|
||||
std::vector<float> input_data{1, 2, 3, 4};
|
||||
data.set_data(input_data.data(), input_data.size() * sizeof(float));
|
||||
*request.add_data() = data;
|
||||
*request.add_data() = data;
|
||||
std::cout << "intput tensor size is " << request.data_size() << std::endl;
|
||||
// Container for the data we expect from the server.
|
||||
PredictReply reply;
|
||||
|
||||
// Context for the client. It could be used to convey extra information to
|
||||
// the server and/or tweak certain RPC behaviors.
|
||||
ClientContext context;
|
||||
// Context for the client. It could be used to convey extra information to
|
||||
// the server and/or tweak certain RPC behaviors.
|
||||
ClientContext context;
|
||||
|
||||
// The actual RPC.
|
||||
Status status = stub_->Predict(&context, request, &reply);
|
||||
std::cout << "Compute [1, 2, 3, 4] + [1, 2, 3, 4]" << std::endl;
|
||||
// The actual RPC.
|
||||
Status status = stub_->Predict(&context, request, &reply);
|
||||
std::cout << "Compute [1, 2, 3, 4] + [1, 2, 3, 4]" << std::endl;
|
||||
|
||||
// Act upon its status.
|
||||
if (status.ok()) {
|
||||
std::cout << "Add result is";
|
||||
for (size_t i = 0; i < reply.result(0).data().size() / sizeof(float); i++) {
|
||||
std::cout << " " << (reinterpret_cast<const float *>(reply.mutable_result(0)->mutable_data()->data()))[i];
|
||||
}
|
||||
std::cout << std::endl;
|
||||
|
||||
// Act upon its status.
|
||||
if (status.ok()) {
|
||||
return "RPC OK";
|
||||
} else {
|
||||
std::cout << status.error_code() << ": " << status.error_message() << std::endl;
|
||||
return "RPC failed";
|
||||
}
|
||||
return "RPC OK";
|
||||
} else {
|
||||
std::cout << status.error_code() << ": " << status.error_message() << std::endl;
|
||||
return "RPC failed";
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<MSService::Stub> stub_;
|
||||
std::unique_ptr<MSService::Stub> stub_;
|
||||
};
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import sys
|
||||
import grpc
|
||||
import numpy as np
|
||||
import ms_service_pb2
|
||||
|
@ -19,7 +20,19 @@ import ms_service_pb2_grpc
|
|||
|
||||
|
||||
def run():
|
||||
channel = grpc.insecure_channel('localhost:5500')
|
||||
if len(sys.argv) > 2:
|
||||
sys.exit("input error")
|
||||
channel_str = ""
|
||||
if len(sys.argv) == 2:
|
||||
split_args = sys.argv[1].split('=')
|
||||
if len(split_args) > 1:
|
||||
channel_str = split_args[1]
|
||||
else:
|
||||
channel_str = 'localhost:5500'
|
||||
else:
|
||||
channel_str = 'localhost:5500'
|
||||
|
||||
channel = grpc.insecure_channel(channel_str)
|
||||
stub = ms_service_pb2_grpc.MSServiceStub(channel)
|
||||
request = ms_service_pb2.PredictRequest()
|
||||
|
||||
|
@ -33,11 +46,17 @@ def run():
|
|||
y.tensor_type = ms_service_pb2.MS_FLOAT32
|
||||
y.data = (np.ones([4]).astype(np.float32)).tobytes()
|
||||
|
||||
result = stub.Predict(request)
|
||||
print(result)
|
||||
result_np = np.frombuffer(result.result[0].data, dtype=np.float32).reshape(result.result[0].tensor_shape.dims)
|
||||
print("ms client received: ")
|
||||
print(result_np)
|
||||
try:
|
||||
result = stub.Predict(request)
|
||||
print(result)
|
||||
result_np = np.frombuffer(result.result[0].data, dtype=np.float32).reshape(result.result[0].tensor_shape.dims)
|
||||
print("ms client received: ")
|
||||
print(result_np)
|
||||
except grpc.RpcError as e:
|
||||
print(e.details())
|
||||
status_code = e.code()
|
||||
print(status_code.name)
|
||||
print(status_code.value)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
|
Loading…
Reference in New Issue