forked from OSSInnovation/mindspore
fix serving bugs
This commit is contained in:
parent
df795daf75
commit
52c58735fc
|
@ -277,10 +277,11 @@ endif ()
|
||||||
|
|
||||||
if (USE_GLOG)
|
if (USE_GLOG)
|
||||||
target_link_libraries(inference PRIVATE mindspore::glog)
|
target_link_libraries(inference PRIVATE mindspore::glog)
|
||||||
else()
|
|
||||||
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
|
|
||||||
target_link_options(inference PRIVATE -Wl,-init,mindspore_log_init)
|
|
||||||
elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin")
|
|
||||||
set_target_properties(inference PROPERTIES MACOSX_RPATH ON)
|
|
||||||
endif ()
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||||
|
target_link_options(inference PRIVATE -Wl,-init,common_log_init)
|
||||||
|
elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin")
|
||||||
|
set_target_properties(inference PROPERTIES MACOSX_RPATH ON)
|
||||||
|
endif ()
|
||||||
|
|
||||||
|
|
|
@ -33,9 +33,14 @@
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
namespace mindspore::inference {
|
namespace mindspore::inference {
|
||||||
std::shared_ptr<FuncGraph> LoadModel(const char *model_buf, size_t size, const std::string &device) {
|
std::shared_ptr<FuncGraph> LoadModel(const char *model_buf, size_t size, const std::string &device) {
|
||||||
|
try {
|
||||||
inference::Session::RegAllOp();
|
inference::Session::RegAllOp();
|
||||||
auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size);
|
auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size);
|
||||||
return anf_graph;
|
return anf_graph;
|
||||||
|
} catch (std::exception &e) {
|
||||||
|
MS_LOG(ERROR) << "Inference LoadModel failed";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ExitInference() {
|
void ExitInference() {
|
||||||
|
@ -51,12 +56,17 @@ void ExitInference() {
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<MSSession> MSSession::CreateSession(const std::string &device, uint32_t device_id) {
|
std::shared_ptr<MSSession> MSSession::CreateSession(const std::string &device, uint32_t device_id) {
|
||||||
|
try {
|
||||||
auto session = std::make_shared<inference::Session>();
|
auto session = std::make_shared<inference::Session>();
|
||||||
auto ret = session->Init(device, device_id);
|
auto ret = session->Init(device, device_id);
|
||||||
if (ret != 0) {
|
if (ret != 0) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
return session;
|
return session;
|
||||||
|
} catch (std::exception &e) {
|
||||||
|
MS_LOG(ERROR) << "Inference CreatSession failed";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Session::RegAllOp() {
|
void Session::RegAllOp() {
|
||||||
|
@ -113,12 +123,18 @@ void Session::RegAllOp() {
|
||||||
|
|
||||||
uint32_t Session::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr) {
|
uint32_t Session::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr) {
|
||||||
MS_ASSERT(session_impl_ != nullptr);
|
MS_ASSERT(session_impl_ != nullptr);
|
||||||
|
try {
|
||||||
auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
|
auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
|
||||||
py::gil_scoped_release gil_release;
|
py::gil_scoped_release gil_release;
|
||||||
return graph_id;
|
return graph_id;
|
||||||
|
} catch (std::exception &e) {
|
||||||
|
MS_LOG(ERROR) << "Inference CompileGraph failed";
|
||||||
|
return static_cast<uint32_t>(-1);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
MultiTensor Session::RunGraph(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) {
|
MultiTensor Session::RunGraph(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) {
|
||||||
|
try {
|
||||||
std::vector<tensor::TensorPtr> inTensors;
|
std::vector<tensor::TensorPtr> inTensors;
|
||||||
inTensors.resize(inputs.size());
|
inTensors.resize(inputs.size());
|
||||||
bool has_error = false;
|
bool has_error = false;
|
||||||
|
@ -146,14 +162,32 @@ MultiTensor Session::RunGraph(uint32_t graph_id, const std::vector<std::shared_p
|
||||||
session_impl_->RunGraph(graph_id, inTensors, &outputs);
|
session_impl_->RunGraph(graph_id, inTensors, &outputs);
|
||||||
|
|
||||||
return TransformVectorRefToMultiTensor(outputs);
|
return TransformVectorRefToMultiTensor(outputs);
|
||||||
|
} catch (std::exception &e) {
|
||||||
|
MS_LOG(ERROR) << "Inference Rungraph failed";
|
||||||
|
return MultiTensor();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
namespace {
|
||||||
|
string AjustTargetName(const std::string &device) {
|
||||||
|
if (device == kAscendDevice) {
|
||||||
|
return std::string(kAscendDevice) + "Inference";
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << "Only support device Ascend right now";
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
int Session::Init(const std::string &device, uint32_t device_id) {
|
int Session::Init(const std::string &device, uint32_t device_id) {
|
||||||
RegAllOp();
|
RegAllOp();
|
||||||
auto ms_context = MsContext::GetInstance();
|
auto ms_context = MsContext::GetInstance();
|
||||||
ms_context->set_execution_mode(kGraphMode);
|
ms_context->set_execution_mode(kGraphMode);
|
||||||
ms_context->set_device_target(kAscendDevice);
|
ms_context->set_device_id(device_id);
|
||||||
session_impl_ = session::SessionFactory::Get().Create(device);
|
auto ajust_device = AjustTargetName(device);
|
||||||
|
if (ajust_device == "") {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
ms_context->set_device_target(device);
|
||||||
|
session_impl_ = session::SessionFactory::Get().Create(ajust_device);
|
||||||
if (session_impl_ == nullptr) {
|
if (session_impl_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available.";
|
MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available.";
|
||||||
return -1;
|
return -1;
|
||||||
|
|
|
@ -463,7 +463,7 @@ void InitSubModulesLogLevel() {
|
||||||
|
|
||||||
// set submodule's log level
|
// set submodule's log level
|
||||||
auto submodule = GetEnv("MS_SUBMODULE_LOG_v");
|
auto submodule = GetEnv("MS_SUBMODULE_LOG_v");
|
||||||
MS_LOG(INFO) << "MS_SUBMODULE_LOG_v=`" << submodule << "`";
|
MS_LOG(DEBUG) << "MS_SUBMODULE_LOG_v=`" << submodule << "`";
|
||||||
LogConfigParser parser(submodule);
|
LogConfigParser parser(submodule);
|
||||||
auto configs = parser.Parse();
|
auto configs = parser.Parse();
|
||||||
for (const auto &cfg : configs) {
|
for (const auto &cfg : configs) {
|
||||||
|
@ -489,22 +489,14 @@ void InitSubModulesLogLevel() {
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
extern "C" {
|
extern "C" {
|
||||||
// shared lib init hook
|
|
||||||
#if defined(_WIN32) || defined(_WIN64)
|
#if defined(_WIN32) || defined(_WIN64)
|
||||||
__attribute__((constructor)) void mindspore_log_init(void) {
|
__attribute__((constructor)) void common_log_init(void) {
|
||||||
#else
|
#else
|
||||||
void mindspore_log_init(void) {
|
void common_log_init(void) {
|
||||||
#endif
|
#endif
|
||||||
#ifdef USE_GLOG
|
#ifdef USE_GLOG
|
||||||
// do not use glog predefined log prefix
|
// do not use glog predefined log prefix
|
||||||
FLAGS_log_prefix = false;
|
FLAGS_log_prefix = false;
|
||||||
static bool is_glog_initialzed = false;
|
|
||||||
if (!is_glog_initialzed) {
|
|
||||||
#if !defined(_WIN32) && !defined(_WIN64)
|
|
||||||
google::InitGoogleLogging("mindspore");
|
|
||||||
#endif
|
|
||||||
is_glog_initialzed = true;
|
|
||||||
}
|
|
||||||
// set default log level to WARNING
|
// set default log level to WARNING
|
||||||
if (mindspore::GetEnv("GLOG_v").empty()) {
|
if (mindspore::GetEnv("GLOG_v").empty()) {
|
||||||
FLAGS_v = mindspore::WARNING;
|
FLAGS_v = mindspore::WARNING;
|
||||||
|
@ -525,4 +517,22 @@ void mindspore_log_init(void) {
|
||||||
#endif
|
#endif
|
||||||
mindspore::InitSubModulesLogLevel();
|
mindspore::InitSubModulesLogLevel();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// shared lib init hook
|
||||||
|
#if defined(_WIN32) || defined(_WIN64)
|
||||||
|
__attribute__((constructor)) void mindspore_log_init(void) {
|
||||||
|
#else
|
||||||
|
void mindspore_log_init(void) {
|
||||||
|
#endif
|
||||||
|
#ifdef USE_GLOG
|
||||||
|
static bool is_glog_initialzed = false;
|
||||||
|
if (!is_glog_initialzed) {
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
google::InitGoogleLogging("mindspore");
|
||||||
|
#endif
|
||||||
|
is_glog_initialzed = true;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
common_log_init();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <future>
|
||||||
|
|
||||||
#include "mindspore/ccsrc/utils/log_adapter.h"
|
#include "mindspore/ccsrc/utils/log_adapter.h"
|
||||||
#include "serving/ms_service.grpc.pb.h"
|
#include "serving/ms_service.grpc.pb.h"
|
||||||
|
@ -40,7 +41,7 @@ namespace serving {
|
||||||
using MSTensorPtr = std::shared_ptr<inference::MSTensor>;
|
using MSTensorPtr = std::shared_ptr<inference::MSTensor>;
|
||||||
|
|
||||||
Status Session::CreatDeviceSession(const std::string &device, uint32_t device_id) {
|
Status Session::CreatDeviceSession(const std::string &device, uint32_t device_id) {
|
||||||
session_ = inference::MSSession::CreateSession(device + "Inference", device_id);
|
session_ = inference::MSSession::CreateSession(device, device_id);
|
||||||
if (session_ == nullptr) {
|
if (session_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Creat Session Failed";
|
MS_LOG(ERROR) << "Creat Session Failed";
|
||||||
return FAILED;
|
return FAILED;
|
||||||
|
@ -67,6 +68,7 @@ Status Session::Predict(const std::vector<MSTensorPtr> &inputs, inference::Multi
|
||||||
MS_LOG(INFO) << "run Predict";
|
MS_LOG(INFO) << "run Predict";
|
||||||
|
|
||||||
*outputs = session_->RunGraph(graph_id_, inputs);
|
*outputs = session_->RunGraph(graph_id_, inputs);
|
||||||
|
MS_LOG(INFO) << "run Predict finished";
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -80,12 +82,16 @@ Status Session::Warmup(const MindSporeModelPtr model) {
|
||||||
std::string file_name = model->GetModelPath() + '/' + model->GetModelName();
|
std::string file_name = model->GetModelPath() + '/' + model->GetModelName();
|
||||||
char *graphBuf = ReadFile(file_name.c_str(), &size);
|
char *graphBuf = ReadFile(file_name.c_str(), &size);
|
||||||
if (graphBuf == nullptr) {
|
if (graphBuf == nullptr) {
|
||||||
MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
|
MS_LOG(ERROR) << "Read model file failed, file name is " << file_name.c_str();
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
last_graph_ = inference::LoadModel(graphBuf, size, device_type_);
|
last_graph_ = inference::LoadModel(graphBuf, size, device_type_);
|
||||||
|
if (last_graph_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
graph_id_ = session_->CompileGraph(last_graph_);
|
graph_id_ = session_->CompileGraph(last_graph_);
|
||||||
MS_LOG(INFO) << "Session Warmup";
|
MS_LOG(INFO) << "Session Warmup finished";
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -95,6 +101,9 @@ Status Session::Clear() {
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
static const uint32_t uint32max = 0x7FFFFFFF;
|
||||||
|
std::promise<void> exit_requested;
|
||||||
|
|
||||||
const std::map<ms_serving::DataType, TypeId> type2id_map{
|
const std::map<ms_serving::DataType, TypeId> type2id_map{
|
||||||
{ms_serving::MS_UNKNOWN, TypeId::kNumberTypeBegin}, {ms_serving::MS_BOOL, TypeId::kNumberTypeBool},
|
{ms_serving::MS_UNKNOWN, TypeId::kNumberTypeBegin}, {ms_serving::MS_BOOL, TypeId::kNumberTypeBool},
|
||||||
{ms_serving::MS_INT8, TypeId::kNumberTypeInt8}, {ms_serving::MS_UINT8, TypeId::kNumberTypeUInt8},
|
{ms_serving::MS_INT8, TypeId::kNumberTypeInt8}, {ms_serving::MS_UINT8, TypeId::kNumberTypeUInt8},
|
||||||
|
@ -141,7 +150,7 @@ MSTensorPtr ServingTensor2MSTensor(const ms_serving::Tensor &tensor) {
|
||||||
}
|
}
|
||||||
TypeId type = iter->second;
|
TypeId type = iter->second;
|
||||||
auto ms_tensor = std::shared_ptr<inference::MSTensor>(inference::MSTensor::CreateTensor(type, shape));
|
auto ms_tensor = std::shared_ptr<inference::MSTensor>(inference::MSTensor::CreateTensor(type, shape));
|
||||||
memcpy_s(ms_tensor->MutableData(), tensor.data().size(), tensor.data().data(), tensor.data().size());
|
memcpy_s(ms_tensor->MutableData(), ms_tensor->Size(), tensor.data().data(), tensor.data().size());
|
||||||
return ms_tensor;
|
return ms_tensor;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -166,10 +175,7 @@ void ClearEnv() {
|
||||||
Session::Instance().Clear();
|
Session::Instance().Clear();
|
||||||
inference::ExitInference();
|
inference::ExitInference();
|
||||||
}
|
}
|
||||||
void HandleSignal(int sig) {
|
void HandleSignal(int sig) { exit_requested.set_value(); }
|
||||||
ClearEnv();
|
|
||||||
exit(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
#ifdef ENABLE_D
|
#ifdef ENABLE_D
|
||||||
static rtContext_t g_ctx = nullptr;
|
static rtContext_t g_ctx = nullptr;
|
||||||
|
@ -247,6 +253,7 @@ Status Server::BuildAndStart() {
|
||||||
rtError_t rt_ret = rtCtxGetCurrent(&ctx);
|
rtError_t rt_ret = rtCtxGetCurrent(&ctx);
|
||||||
if (rt_ret != RT_ERROR_NONE || ctx == nullptr) {
|
if (rt_ret != RT_ERROR_NONE || ctx == nullptr) {
|
||||||
MS_LOG(ERROR) << "the ascend device context is null";
|
MS_LOG(ERROR) << "the ascend device context is null";
|
||||||
|
ClearEnv();
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
g_ctx = ctx;
|
g_ctx = ctx;
|
||||||
|
@ -258,6 +265,7 @@ Status Server::BuildAndStart() {
|
||||||
auto option = grpc::MakeChannelArgumentOption(GRPC_ARG_ALLOW_REUSEPORT, 0);
|
auto option = grpc::MakeChannelArgumentOption(GRPC_ARG_ALLOW_REUSEPORT, 0);
|
||||||
grpc::ServerBuilder builder;
|
grpc::ServerBuilder builder;
|
||||||
builder.SetOption(std::move(option));
|
builder.SetOption(std::move(option));
|
||||||
|
builder.SetMaxMessageSize(uint32max);
|
||||||
// Listen on the given address without any authentication mechanism.
|
// Listen on the given address without any authentication mechanism.
|
||||||
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
||||||
// Register "service" as the instance through which we'll communicate with
|
// Register "service" as the instance through which we'll communicate with
|
||||||
|
@ -265,13 +273,15 @@ Status Server::BuildAndStart() {
|
||||||
builder.RegisterService(&service);
|
builder.RegisterService(&service);
|
||||||
// Finally assemble the server.
|
// Finally assemble the server.
|
||||||
std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
|
std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
|
||||||
|
auto grpc_server_run = [&server]() { server->Wait(); };
|
||||||
|
std::thread serving_thread(grpc_server_run);
|
||||||
MS_LOG(INFO) << "Server listening on " << server_address << std::endl;
|
MS_LOG(INFO) << "Server listening on " << server_address << std::endl;
|
||||||
|
auto exit_future = exit_requested.get_future();
|
||||||
// Wait for the server to shutdown. Note that some other thread must be
|
exit_future.wait();
|
||||||
// responsible for shutting down the server for this call to ever return.
|
ClearEnv();
|
||||||
server->Wait();
|
server->Shutdown();
|
||||||
|
serving_thread.join();
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace serving
|
} // namespace serving
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -29,7 +29,6 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace serving {
|
namespace serving {
|
||||||
|
|
||||||
char *ReadFile(const char *file, size_t *size) {
|
char *ReadFile(const char *file, size_t *size) {
|
||||||
if (file == nullptr) {
|
if (file == nullptr) {
|
||||||
MS_LOG(ERROR) << "file is nullptr";
|
MS_LOG(ERROR) << "file is nullptr";
|
||||||
|
@ -70,8 +69,8 @@ bool DirOrFileExist(const std::string &file_path) {
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> GetAllSubDirs(const std::string &dir_path) {
|
std::vector<std::string> GetAllSubDirs(const std::string &dir_path) {
|
||||||
DIR *dir;
|
DIR *dir = nullptr;
|
||||||
struct dirent *ptr;
|
struct dirent *ptr = nullptr;
|
||||||
std::vector<std::string> SubDirs;
|
std::vector<std::string> SubDirs;
|
||||||
|
|
||||||
if ((dir = opendir(dir_path.c_str())) == NULL) {
|
if ((dir = opendir(dir_path.c_str())) == NULL) {
|
||||||
|
|
|
@ -36,17 +36,16 @@ bool RemovePrefix(std::string *str, const std::string &prefix) {
|
||||||
|
|
||||||
bool Option::ParseInt32(std::string *arg) {
|
bool Option::ParseInt32(std::string *arg) {
|
||||||
if (RemovePrefix(arg, "--") && RemovePrefix(arg, name_) && RemovePrefix(arg, "=")) {
|
if (RemovePrefix(arg, "--") && RemovePrefix(arg, name_) && RemovePrefix(arg, "=")) {
|
||||||
char extra;
|
|
||||||
int32_t parsed_value;
|
int32_t parsed_value;
|
||||||
if (sscanf(arg->data(), "%d%c", &parsed_value, &extra) != 1) {
|
try {
|
||||||
std::cout << "Parse " << name_ << "Error for option " << *arg << std::endl;
|
parsed_value = std::stoi(arg->data());
|
||||||
|
} catch (std::invalid_argument) {
|
||||||
|
std::cout << "Parse " << name_ << " Error for option " << *arg << std::endl;
|
||||||
return false;
|
return false;
|
||||||
} else {
|
|
||||||
*int32_default_ = parsed_value;
|
|
||||||
}
|
}
|
||||||
|
*int32_default_ = parsed_value;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -76,17 +75,16 @@ bool Option::ParseString(std::string *arg) {
|
||||||
|
|
||||||
bool Option::ParseFloat(std::string *arg) {
|
bool Option::ParseFloat(std::string *arg) {
|
||||||
if (RemovePrefix(arg, "--") && RemovePrefix(arg, name_) && RemovePrefix(arg, "=")) {
|
if (RemovePrefix(arg, "--") && RemovePrefix(arg, name_) && RemovePrefix(arg, "=")) {
|
||||||
char extra;
|
|
||||||
float parsed_value;
|
float parsed_value;
|
||||||
if (sscanf(arg->data(), "%f%c", &parsed_value, &extra) != 1) {
|
try {
|
||||||
std::cout << "Parse " << name_ << "Error for option " << *arg << std::endl;
|
parsed_value = std::stof(arg->data());
|
||||||
|
} catch (std::invalid_argument) {
|
||||||
|
std::cout << "Parse " << name_ << " Error for option " << *arg << std::endl;
|
||||||
return false;
|
return false;
|
||||||
} else {
|
|
||||||
*float_default_ = parsed_value;
|
|
||||||
}
|
}
|
||||||
|
*float_default_ = parsed_value;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -159,10 +157,11 @@ Options::Options() : args_(nullptr) { CreateOptions(); }
|
||||||
void Options::CreateOptions() {
|
void Options::CreateOptions() {
|
||||||
args_ = std::make_shared<Arguments>();
|
args_ = std::make_shared<Arguments>();
|
||||||
std::vector<Option> options = {
|
std::vector<Option> options = {
|
||||||
Option("port", &args_->grpc_port, "Port to listen on for gRPC API, default is 5500"),
|
Option("port", &args_->grpc_port,
|
||||||
Option("model_name", &args_->model_name, "model name "),
|
"[Optional] Port to listen on for gRPC API, default is 5500, range from 1 to 65535"),
|
||||||
Option("model_path", &args_->model_path, "the path of the model files"),
|
Option("model_name", &args_->model_name, "[Required] model name "),
|
||||||
Option("device_id", &args_->device_id, "the device id, default is 0"),
|
Option("model_path", &args_->model_path, "[Required] the path of the model files"),
|
||||||
|
Option("device_id", &args_->device_id, "[Optional] the device id, default is 0, range from 0 to 7"),
|
||||||
};
|
};
|
||||||
options_ = options;
|
options_ = options;
|
||||||
}
|
}
|
||||||
|
@ -176,6 +175,14 @@ bool Options::CheckOptions() {
|
||||||
std::cout << "device_type only support Ascend right now" << std::endl;
|
std::cout << "device_type only support Ascend right now" << std::endl;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
if (args_->device_id > 7) {
|
||||||
|
std::cout << "the device_id should be in [0~7]" << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (args_->grpc_port < 1 || args_->grpc_port > 65535) {
|
||||||
|
std::cout << "the port should be in [1~65535]" << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -238,6 +245,5 @@ void Options::Usage() {
|
||||||
<< option.usage_ << std::endl;
|
<< option.usage_ << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace serving
|
} // namespace serving
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -22,7 +22,6 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace serving {
|
namespace serving {
|
||||||
|
|
||||||
struct Arguments {
|
struct Arguments {
|
||||||
int32_t grpc_port = 5500;
|
int32_t grpc_port = 5500;
|
||||||
std::string grpc_socket_path;
|
std::string grpc_socket_path;
|
||||||
|
@ -40,6 +39,7 @@ class Option {
|
||||||
Option(const std::string &name, bool *default_point, const std::string &usage);
|
Option(const std::string &name, bool *default_point, const std::string &usage);
|
||||||
Option(const std::string &name, std::string *default_point, const std::string &usage);
|
Option(const std::string &name, std::string *default_point, const std::string &usage);
|
||||||
Option(const std::string &name, float *default_point, const std::string &usage);
|
Option(const std::string &name, float *default_point, const std::string &usage);
|
||||||
|
~Option() = default;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend class Options;
|
friend class Options;
|
||||||
|
@ -77,7 +77,6 @@ class Options {
|
||||||
std::vector<Option> options_;
|
std::vector<Option> options_;
|
||||||
std::shared_ptr<Arguments> args_;
|
std::shared_ptr<Arguments> args_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace serving
|
} // namespace serving
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,6 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace serving {
|
namespace serving {
|
||||||
|
|
||||||
MindSporeModel::MindSporeModel(const std::string &model_name, const std::string &model_path,
|
MindSporeModel::MindSporeModel(const std::string &model_name, const std::string &model_path,
|
||||||
const std::string &model_version, const time_t &last_update_time)
|
const std::string &model_version, const time_t &last_update_time)
|
||||||
: model_name_(model_name),
|
: model_name_(model_name),
|
||||||
|
|
|
@ -25,7 +25,6 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace serving {
|
namespace serving {
|
||||||
|
|
||||||
volatile bool stop_poll = false;
|
volatile bool stop_poll = false;
|
||||||
|
|
||||||
std::string GetVersionFromPath(const std::string &path) {
|
std::string GetVersionFromPath(const std::string &path) {
|
||||||
|
@ -102,10 +101,10 @@ Status VersionController::CreateInitModels() {
|
||||||
}
|
}
|
||||||
std::vector<std::string> SubDirs = GetAllSubDirs(models_path_);
|
std::vector<std::string> SubDirs = GetAllSubDirs(models_path_);
|
||||||
if (version_control_strategy_ == kLastest) {
|
if (version_control_strategy_ == kLastest) {
|
||||||
auto path = SubDirs.empty() ? models_path_ : SubDirs.back();
|
std::string model_version = GetVersionFromPath(models_path_);
|
||||||
std::string model_version = GetVersionFromPath(path);
|
time_t last_update_time = GetModifyTime(models_path_);
|
||||||
time_t last_update_time = GetModifyTime(path);
|
MindSporeModelPtr model_ptr =
|
||||||
MindSporeModelPtr model_ptr = std::make_shared<MindSporeModel>(model_name_, path, model_version, last_update_time);
|
std::make_shared<MindSporeModel>(model_name_, models_path_, model_version, last_update_time);
|
||||||
valid_models_.emplace_back(model_ptr);
|
valid_models_.emplace_back(model_ptr);
|
||||||
} else {
|
} else {
|
||||||
for (auto &dir : SubDirs) {
|
for (auto &dir : SubDirs) {
|
||||||
|
@ -119,8 +118,8 @@ Status VersionController::CreateInitModels() {
|
||||||
MS_LOG(ERROR) << "There is no valid model for serving";
|
MS_LOG(ERROR) << "There is no valid model for serving";
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
Session::Instance().Warmup(valid_models_.back());
|
auto ret = Session::Instance().Warmup(valid_models_.back());
|
||||||
return SUCCESS;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
void VersionController::StartPollModelPeriodic() {
|
void VersionController::StartPollModelPeriodic() {
|
||||||
|
@ -129,6 +128,5 @@ void VersionController::StartPollModelPeriodic() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void VersionController::StopPollModelPeriodic() {}
|
void VersionController::StopPollModelPeriodic() {}
|
||||||
|
|
||||||
} // namespace serving
|
} // namespace serving
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -64,7 +64,6 @@ class PeriodicFunction {
|
||||||
VersionController::VersionControllerStrategy version_control_strategy_;
|
VersionController::VersionControllerStrategy version_control_strategy_;
|
||||||
std::vector<MindSporeModelPtr> valid_models_;
|
std::vector<MindSporeModelPtr> valid_models_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace serving
|
} // namespace serving
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -214,6 +214,7 @@ PredictRequest ReadBertInput() {
|
||||||
class MSClient {
|
class MSClient {
|
||||||
public:
|
public:
|
||||||
explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {}
|
explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {}
|
||||||
|
~MSClient() = default;
|
||||||
|
|
||||||
std::string Predict(const std::string &type) {
|
std::string Predict(const std::string &type) {
|
||||||
// Data we are sending to the server.
|
// Data we are sending to the server.
|
||||||
|
@ -310,7 +311,6 @@ int main(int argc, char **argv) {
|
||||||
type = "add";
|
type = "add";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
target_str = "localhost:5500";
|
target_str = "localhost:5500";
|
||||||
type = "add";
|
type = "add";
|
||||||
|
|
Loading…
Reference in New Issue