forked from mindspore-Ecosystem/mindspore
init serving
This commit is contained in:
parent
74d12d738d
commit
8d76c708df
|
@ -96,4 +96,8 @@ if (ENABLE_TESTCASES)
|
||||||
add_subdirectory(tests)
|
add_subdirectory(tests)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if (ENABLE_SERVING)
|
||||||
|
add_subdirectory(serving)
|
||||||
|
endif()
|
||||||
|
|
||||||
include(cmake/package.cmake)
|
include(cmake/package.cmake)
|
||||||
|
|
13
build.sh
13
build.sh
|
@ -53,6 +53,7 @@ usage()
|
||||||
echo " -V Specify the minimum required cuda version, default CUDA 9.2"
|
echo " -V Specify the minimum required cuda version, default CUDA 9.2"
|
||||||
echo " -I Compile predict, default off"
|
echo " -I Compile predict, default off"
|
||||||
echo " -K Compile with AKG, default off"
|
echo " -K Compile with AKG, default off"
|
||||||
|
echo " -s Enable serving module, default off"
|
||||||
}
|
}
|
||||||
|
|
||||||
# check value of input is 'on' or 'off'
|
# check value of input is 'on' or 'off'
|
||||||
|
@ -92,9 +93,9 @@ checkopts()
|
||||||
USE_GLOG="on"
|
USE_GLOG="on"
|
||||||
PREDICT_PLATFORM=""
|
PREDICT_PLATFORM=""
|
||||||
ENABLE_AKG="off"
|
ENABLE_AKG="off"
|
||||||
|
ENABLE_SERVING="off"
|
||||||
# Process the options
|
# Process the options
|
||||||
while getopts 'drvj:c:t:hsb:a:g:p:ie:m:I:LRP:Q:D:zM:V:K' opt
|
while getopts 'drvj:c:t:hsb:a:g:p:ie:m:I:LRP:Q:D:zM:V:K:s' opt
|
||||||
do
|
do
|
||||||
OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]')
|
OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]')
|
||||||
case "${opt}" in
|
case "${opt}" in
|
||||||
|
@ -235,6 +236,10 @@ checkopts()
|
||||||
ENABLE_AKG="on"
|
ENABLE_AKG="on"
|
||||||
echo "enable compile with akg"
|
echo "enable compile with akg"
|
||||||
;;
|
;;
|
||||||
|
s)
|
||||||
|
ENABLE_SERVING="on"
|
||||||
|
echo "enable serving"
|
||||||
|
;;
|
||||||
*)
|
*)
|
||||||
echo "Unknown option ${opt}!"
|
echo "Unknown option ${opt}!"
|
||||||
usage
|
usage
|
||||||
|
@ -314,6 +319,10 @@ build_mindspore()
|
||||||
if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" ]]; then
|
if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" ]]; then
|
||||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_AKG=ON"
|
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_AKG=ON"
|
||||||
fi
|
fi
|
||||||
|
if [[ "X$ENABLE_SERVING" = "Xon" ]]; then
|
||||||
|
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_SERVING=ON"
|
||||||
|
fi
|
||||||
|
|
||||||
echo "${CMAKE_ARGS}"
|
echo "${CMAKE_ARGS}"
|
||||||
if [[ "X$INC_BUILD" = "Xoff" ]]; then
|
if [[ "X$INC_BUILD" = "Xoff" ]]; then
|
||||||
cmake ${CMAKE_ARGS} ../..
|
cmake ${CMAKE_ARGS} ../..
|
||||||
|
|
|
@ -37,6 +37,8 @@ class MS_API MSSession {
|
||||||
};
|
};
|
||||||
|
|
||||||
std::shared_ptr<FuncGraph> MS_API LoadModel(const char *model_buf, size_t size, const std::string &device);
|
std::shared_ptr<FuncGraph> MS_API LoadModel(const char *model_buf, size_t size, const std::string &device);
|
||||||
|
|
||||||
|
void MS_API ExitInference();
|
||||||
} // namespace inference
|
} // namespace inference
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_INCLUDE_MS_SESSION_H
|
#endif // MINDSPORE_INCLUDE_MS_SESSION_H
|
||||||
|
|
|
@ -247,7 +247,7 @@ add_library(inference SHARED
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/session/session.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/session/session.cc
|
||||||
${LOAD_ONNX_SRC}
|
${LOAD_ONNX_SRC}
|
||||||
)
|
)
|
||||||
target_link_libraries(inference PRIVATE ${PYTHON_LIB} ${SECUREC_LIBRARY}
|
target_link_libraries(inference PRIVATE ${PYTHON_LIBRARY} ${SECUREC_LIBRARY}
|
||||||
-Wl,--whole-archive mindspore -Wl,--no-whole-archive mindspore_gvar mindspore::protobuf)
|
-Wl,--whole-archive mindspore -Wl,--no-whole-archive mindspore_gvar mindspore::protobuf)
|
||||||
|
|
||||||
if (ENABLE_CPU)
|
if (ENABLE_CPU)
|
||||||
|
|
|
@ -38,6 +38,18 @@ std::shared_ptr<FuncGraph> LoadModel(const char *model_buf, size_t size, const s
|
||||||
return anf_graph;
|
return anf_graph;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ExitInference() {
|
||||||
|
auto ms_context = MsContext::GetInstance();
|
||||||
|
if (ms_context == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Get Context failed!";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!ms_context->CloseTsd()) {
|
||||||
|
MS_LOG(ERROR) << "Inference CloseTsd failed!";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::shared_ptr<MSSession> MSSession::CreateSession(const std::string &device, uint32_t device_id) {
|
std::shared_ptr<MSSession> MSSession::CreateSession(const std::string &device, uint32_t device_id) {
|
||||||
auto session = std::make_shared<inference::Session>();
|
auto session = std::make_shared<inference::Session>();
|
||||||
auto ret = session->Init(device, device_id);
|
auto ret = session->Init(device, device_id);
|
||||||
|
@ -101,11 +113,14 @@ void Session::RegAllOp() {
|
||||||
|
|
||||||
uint32_t Session::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr) {
|
uint32_t Session::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr) {
|
||||||
MS_ASSERT(session_impl_ != nullptr);
|
MS_ASSERT(session_impl_ != nullptr);
|
||||||
return session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
|
auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
|
||||||
|
py::gil_scoped_release gil_release;
|
||||||
|
return graph_id;
|
||||||
}
|
}
|
||||||
|
|
||||||
MultiTensor Session::RunGraph(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) {
|
MultiTensor Session::RunGraph(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) {
|
||||||
std::vector<tensor::TensorPtr> inTensors;
|
std::vector<tensor::TensorPtr> inTensors;
|
||||||
|
inTensors.resize(inputs.size());
|
||||||
bool has_error = false;
|
bool has_error = false;
|
||||||
std::transform(inputs.begin(), inputs.end(), inTensors.begin(),
|
std::transform(inputs.begin(), inputs.end(), inTensors.begin(),
|
||||||
[&has_error](const std::shared_ptr<inference::MSTensor> &tensor_ptr) -> tensor::TensorPtr {
|
[&has_error](const std::shared_ptr<inference::MSTensor> &tensor_ptr) -> tensor::TensorPtr {
|
||||||
|
@ -144,6 +159,14 @@ int Session::Init(const std::string &device, uint32_t device_id) {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
session_impl_->Init(device_id);
|
session_impl_->Init(device_id);
|
||||||
|
if (ms_context == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Get Context failed!";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
if (!ms_context->OpenTsd()) {
|
||||||
|
MS_LOG(ERROR) << "Session init OpenTsd failed!";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,69 @@
|
||||||
|
find_package(Threads REQUIRED)
|
||||||
|
|
||||||
|
# This branch assumes that gRPC and all its dependencies are already installed
|
||||||
|
# on this system, so they can be located by find_package().
|
||||||
|
|
||||||
|
# Find Protobuf installation
|
||||||
|
# Looks for protobuf-config.cmake file installed by Protobuf's cmake installation.
|
||||||
|
|
||||||
|
#set(protobuf_MODULE_COMPATIBLE TRUE)
|
||||||
|
#find_package(Protobuf CONFIG REQUIRED)
|
||||||
|
#message(STATUS "Using protobuf ${protobuf_VERSION}")
|
||||||
|
add_library(protobuf::libprotobuf ALIAS protobuf::protobuf)
|
||||||
|
add_executable(protobuf::libprotoc ALIAS protobuf::protoc)
|
||||||
|
|
||||||
|
set(_PROTOBUF_LIBPROTOBUF protobuf::libprotobuf)
|
||||||
|
set(_REFLECTION gRPC::grpc++_reflection)
|
||||||
|
if(CMAKE_CROSSCOMPILING)
|
||||||
|
find_program(_PROTOBUF_PROTOC protoc)
|
||||||
|
else()
|
||||||
|
set(_PROTOBUF_PROTOC $<TARGET_FILE:protobuf::protoc>)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# Find gRPC installation
|
||||||
|
# Looks for gRPCConfig.cmake file installed by gRPC's cmake installation.
|
||||||
|
find_package(gRPC CONFIG REQUIRED)
|
||||||
|
message(STATUS "Using gRPC ${gRPC_VERSION}")
|
||||||
|
|
||||||
|
set(_GRPC_GRPCPP gRPC::grpc++)
|
||||||
|
if(CMAKE_CROSSCOMPILING)
|
||||||
|
find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin)
|
||||||
|
else()
|
||||||
|
set(_GRPC_CPP_PLUGIN_EXECUTABLE $<TARGET_FILE:gRPC::grpc_cpp_plugin>)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# Proto file
|
||||||
|
get_filename_component(hw_proto "ms_service.proto" ABSOLUTE)
|
||||||
|
get_filename_component(hw_proto_path "${hw_proto}" PATH)
|
||||||
|
|
||||||
|
# Generated sources
|
||||||
|
set(hw_proto_srcs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.pb.cc")
|
||||||
|
set(hw_proto_hdrs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.pb.h")
|
||||||
|
set(hw_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.grpc.pb.cc")
|
||||||
|
set(hw_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.grpc.pb.h")
|
||||||
|
add_custom_command(
|
||||||
|
OUTPUT "${hw_proto_srcs}" "${hw_proto_hdrs}" "${hw_grpc_srcs}" "${hw_grpc_hdrs}"
|
||||||
|
COMMAND ${_PROTOBUF_PROTOC}
|
||||||
|
ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}"
|
||||||
|
--cpp_out "${CMAKE_CURRENT_BINARY_DIR}"
|
||||||
|
-I "${hw_proto_path}"
|
||||||
|
--plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}"
|
||||||
|
"${hw_proto}"
|
||||||
|
DEPENDS "${hw_proto}")
|
||||||
|
|
||||||
|
# Include generated *.pb.h files
|
||||||
|
include_directories("${CMAKE_CURRENT_BINARY_DIR}" "${CMAKE_CURRENT_SOURCE_DIR}" "${CMAKE_CURRENT_SOURCE_DIR}/core"
|
||||||
|
"${PROJECT_SOURCE_DIR}/mindspore/ccsrc")
|
||||||
|
file(GLOB_RECURSE CORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
|
"core/*.cc" "core/util/*.cc" "core/version_control/*.cc")
|
||||||
|
|
||||||
|
list(APPEND SERVING_SRC "main.cc" ${hw_proto_srcs} ${hw_grpc_srcs} ${CORE_SRC_LIST})
|
||||||
|
|
||||||
|
include_directories(${CMAKE_BINARY_DIR})
|
||||||
|
add_executable(ms_serving ${SERVING_SRC})
|
||||||
|
target_link_libraries(ms_serving inference mindspore_gvar)
|
||||||
|
target_link_libraries(ms_serving ${_REFLECTION} ${_GRPC_GRPCPP} ${_PROTOBUF_LIBPROTOBUF} pthread)
|
||||||
|
if (ENABLE_D)
|
||||||
|
add_compile_definitions(ENABLE_D)
|
||||||
|
target_link_libraries(ms_serving ${RUNTIME_LIB})
|
||||||
|
endif()
|
|
@ -0,0 +1,36 @@
|
||||||
|
# serving
|
||||||
|
|
||||||
|
#### Description
|
||||||
|
A flexible, high-performance serving system for deep learning models
|
||||||
|
|
||||||
|
#### Software Architecture
|
||||||
|
Software architecture description
|
||||||
|
|
||||||
|
#### Installation
|
||||||
|
|
||||||
|
1. xxxx
|
||||||
|
2. xxxx
|
||||||
|
3. xxxx
|
||||||
|
|
||||||
|
#### Instructions
|
||||||
|
|
||||||
|
1. xxxx
|
||||||
|
2. xxxx
|
||||||
|
3. xxxx
|
||||||
|
|
||||||
|
#### Contribution
|
||||||
|
|
||||||
|
1. Fork the repository
|
||||||
|
2. Create Feat_xxx branch
|
||||||
|
3. Commit your code
|
||||||
|
4. Create Pull Request
|
||||||
|
|
||||||
|
|
||||||
|
#### Gitee Feature
|
||||||
|
|
||||||
|
1. You can use Readme\_XXX.md to support different languages, such as Readme\_en.md, Readme\_zh.md
|
||||||
|
2. Gitee blog [blog.gitee.com](https://blog.gitee.com)
|
||||||
|
3. Explore open source project [https://gitee.com/explore](https://gitee.com/explore)
|
||||||
|
4. The most valuable open source project [GVP](https://gitee.com/gvp)
|
||||||
|
5. The manual of Gitee [https://gitee.com/help](https://gitee.com/help)
|
||||||
|
6. The most popular members [https://gitee.com/gitee-stars/](https://gitee.com/gitee-stars/)
|
|
@ -0,0 +1,37 @@
|
||||||
|
# serving
|
||||||
|
|
||||||
|
#### 介绍
|
||||||
|
A flexible, high-performance serving system for deep learning models
|
||||||
|
|
||||||
|
#### 软件架构
|
||||||
|
软件架构说明
|
||||||
|
|
||||||
|
|
||||||
|
#### 安装教程
|
||||||
|
|
||||||
|
1. xxxx
|
||||||
|
2. xxxx
|
||||||
|
3. xxxx
|
||||||
|
|
||||||
|
#### 使用说明
|
||||||
|
|
||||||
|
1. xxxx
|
||||||
|
2. xxxx
|
||||||
|
3. xxxx
|
||||||
|
|
||||||
|
#### 参与贡献
|
||||||
|
|
||||||
|
1. Fork 本仓库
|
||||||
|
2. 新建 Feat_xxx 分支
|
||||||
|
3. 提交代码
|
||||||
|
4. 新建 Pull Request
|
||||||
|
|
||||||
|
|
||||||
|
#### 码云特技
|
||||||
|
|
||||||
|
1. 使用 Readme\_XXX.md 来支持不同的语言,例如 Readme\_en.md, Readme\_zh.md
|
||||||
|
2. 码云官方博客 [blog.gitee.com](https://blog.gitee.com)
|
||||||
|
3. 你可以 [https://gitee.com/explore](https://gitee.com/explore) 这个地址来了解码云上的优秀开源项目
|
||||||
|
4. [GVP](https://gitee.com/gvp) 全称是码云最有价值开源项目,是码云综合评定出的优秀开源项目
|
||||||
|
5. 码云官方提供的使用手册 [https://gitee.com/help](https://gitee.com/help)
|
||||||
|
6. 码云封面人物是一档用来展示码云会员风采的栏目 [https://gitee.com/gitee-stars/](https://gitee.com/gitee-stars/)
|
|
@ -0,0 +1,277 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "core/server.h"
|
||||||
|
#include <grpcpp/grpcpp.h>
|
||||||
|
#include <grpcpp/health_check_service_interface.h>
|
||||||
|
#include <grpcpp/ext/proto_server_reflection_plugin.h>
|
||||||
|
#include <string>
|
||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
|
#include <utility>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "mindspore/ccsrc/utils/log_adapter.h"
|
||||||
|
#include "serving/ms_service.grpc.pb.h"
|
||||||
|
#include "core/util/option_parser.h"
|
||||||
|
#include "core/version_control/version_controller.h"
|
||||||
|
#include "mindspore/ccsrc/utils/context/ms_context.h"
|
||||||
|
#include "core/util/file_system_operation.h"
|
||||||
|
#include "graphengine/third_party/fwkacllib/inc/runtime/context.h"
|
||||||
|
|
||||||
|
using ms_serving::MSService;
|
||||||
|
using ms_serving::PredictReply;
|
||||||
|
using ms_serving::PredictRequest;
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace serving {
|
||||||
|
using MSTensorPtr = std::shared_ptr<inference::MSTensor>;
|
||||||
|
|
||||||
|
Status Session::CreatDeviceSession(const std::string &device, uint32_t device_id) {
|
||||||
|
session_ = inference::MSSession::CreateSession(device + "Inference", device_id);
|
||||||
|
if (session_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Creat Session Failed";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
device_type_ = device;
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
Session &Session::Instance() {
|
||||||
|
static Session instance;
|
||||||
|
return instance;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Session::Predict(const std::vector<MSTensorPtr> &inputs, inference::MultiTensor *outputs) {
|
||||||
|
if (last_graph_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "the model has not loaded";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
if (session_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "the inference session has not be initialized";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
|
MS_LOG(INFO) << "run Predict";
|
||||||
|
|
||||||
|
*outputs = session_->RunGraph(graph_id_, inputs);
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Session::Warmup(const MindSporeModelPtr model) {
|
||||||
|
if (session_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "The CreatDeviceSession should be called, before warmup";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
|
size_t size = 0;
|
||||||
|
std::string file_name = model->GetModelPath() + '/' + model->GetModelName();
|
||||||
|
char *graphBuf = ReadFile(file_name.c_str(), &size);
|
||||||
|
if (graphBuf == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
last_graph_ = inference::LoadModel(graphBuf, size, device_type_);
|
||||||
|
graph_id_ = session_->CompileGraph(last_graph_);
|
||||||
|
MS_LOG(INFO) << "Session Warmup";
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Session::Clear() {
|
||||||
|
session_ = nullptr;
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
const std::map<ms_serving::DataType, TypeId> type2id_map{
|
||||||
|
{ms_serving::MS_UNKNOWN, TypeId::kNumberTypeBegin}, {ms_serving::MS_BOOL, TypeId::kNumberTypeBool},
|
||||||
|
{ms_serving::MS_INT8, TypeId::kNumberTypeInt8}, {ms_serving::MS_UINT8, TypeId::kNumberTypeUInt8},
|
||||||
|
{ms_serving::MS_INT16, TypeId::kNumberTypeInt16}, {ms_serving::MS_UINT16, TypeId::kNumberTypeUInt16},
|
||||||
|
{ms_serving::MS_INT32, TypeId::kNumberTypeInt32}, {ms_serving::MS_UINT32, TypeId::kNumberTypeUInt32},
|
||||||
|
{ms_serving::MS_INT64, TypeId::kNumberTypeInt64}, {ms_serving::MS_UINT64, TypeId::kNumberTypeUInt64},
|
||||||
|
{ms_serving::MS_FLOAT16, TypeId::kNumberTypeFloat16}, {ms_serving::MS_FLOAT32, TypeId::kNumberTypeFloat32},
|
||||||
|
{ms_serving::MS_FLOAT64, TypeId::kNumberTypeFloat64},
|
||||||
|
};
|
||||||
|
|
||||||
|
const std::map<TypeId, ms_serving::DataType> id2type_map{
|
||||||
|
{TypeId::kNumberTypeBegin, ms_serving::MS_UNKNOWN}, {TypeId::kNumberTypeBool, ms_serving::MS_BOOL},
|
||||||
|
{TypeId::kNumberTypeInt8, ms_serving::MS_INT8}, {TypeId::kNumberTypeUInt8, ms_serving::MS_UINT8},
|
||||||
|
{TypeId::kNumberTypeInt16, ms_serving::MS_INT16}, {TypeId::kNumberTypeUInt16, ms_serving::MS_UINT16},
|
||||||
|
{TypeId::kNumberTypeInt32, ms_serving::MS_INT32}, {TypeId::kNumberTypeUInt32, ms_serving::MS_UINT32},
|
||||||
|
{TypeId::kNumberTypeInt64, ms_serving::MS_INT64}, {TypeId::kNumberTypeUInt64, ms_serving::MS_UINT64},
|
||||||
|
{TypeId::kNumberTypeFloat16, ms_serving::MS_FLOAT16}, {TypeId::kNumberTypeFloat32, ms_serving::MS_FLOAT32},
|
||||||
|
{TypeId::kNumberTypeFloat64, ms_serving::MS_FLOAT64},
|
||||||
|
};
|
||||||
|
const std::map<ms_serving::DataType, size_t> length_map{
|
||||||
|
{ms_serving::MS_UNKNOWN, 0},
|
||||||
|
{ms_serving::MS_BOOL, sizeof(bool)},
|
||||||
|
{ms_serving::MS_INT8, sizeof(int8_t)},
|
||||||
|
{ms_serving::MS_UINT8, sizeof(uint8_t)},
|
||||||
|
{ms_serving::MS_INT16, sizeof(int16_t)},
|
||||||
|
{ms_serving::MS_UINT16, sizeof(uint16_t)},
|
||||||
|
{ms_serving::MS_INT32, sizeof(int32_t)},
|
||||||
|
{ms_serving::MS_UINT32, sizeof(uint32_t)},
|
||||||
|
{ms_serving::MS_INT64, sizeof(int64_t)},
|
||||||
|
{ms_serving::MS_UINT64, sizeof(uint64_t)},
|
||||||
|
{ms_serving::MS_FLOAT16, 2},
|
||||||
|
{ms_serving::MS_FLOAT32, 4},
|
||||||
|
{ms_serving::MS_FLOAT64, 8},
|
||||||
|
};
|
||||||
|
MSTensorPtr ServingTensor2MSTensor(const ms_serving::Tensor &tensor) {
|
||||||
|
std::vector<int> shape;
|
||||||
|
for (auto dim : tensor.tensor_shape().dims()) {
|
||||||
|
shape.push_back(static_cast<int>(dim));
|
||||||
|
}
|
||||||
|
auto iter = type2id_map.find(tensor.tensor_type());
|
||||||
|
if (iter == type2id_map.end()) {
|
||||||
|
MS_LOG(ERROR) << "input tensor type is wrong, type is " << tensor.tensor_type();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
TypeId type = iter->second;
|
||||||
|
auto ms_tensor = std::shared_ptr<inference::MSTensor>(inference::MSTensor::CreateTensor(type, shape));
|
||||||
|
memcpy_s(ms_tensor->MutableData(), tensor.data().size(), tensor.data().data(), tensor.data().size());
|
||||||
|
return ms_tensor;
|
||||||
|
}
|
||||||
|
|
||||||
|
ms_serving::Tensor MSTensor2ServingTensor(MSTensorPtr ms_tensor) {
|
||||||
|
ms_serving::Tensor tensor;
|
||||||
|
ms_serving::TensorShape shape;
|
||||||
|
for (auto dim : ms_tensor->shape()) {
|
||||||
|
shape.add_dims(dim);
|
||||||
|
}
|
||||||
|
*tensor.mutable_tensor_shape() = shape;
|
||||||
|
auto iter = id2type_map.find(ms_tensor->data_type());
|
||||||
|
if (iter == id2type_map.end()) {
|
||||||
|
MS_LOG(ERROR) << "input tensor type is wrong, type is " << tensor.tensor_type();
|
||||||
|
return tensor;
|
||||||
|
}
|
||||||
|
tensor.set_tensor_type(iter->second);
|
||||||
|
tensor.set_data(ms_tensor->MutableData(), ms_tensor->Size());
|
||||||
|
return tensor;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ClearEnv() {
|
||||||
|
Session::Instance().Clear();
|
||||||
|
inference::ExitInference();
|
||||||
|
}
|
||||||
|
void HandleSignal(int sig) {
|
||||||
|
ClearEnv();
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef ENABLE_D
|
||||||
|
static rtContext_t g_ctx = nullptr;
|
||||||
|
#endif
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Service Implement
|
||||||
|
class MSServiceImpl final : public MSService::Service {
|
||||||
|
grpc::Status Predict(grpc::ServerContext *context, const PredictRequest *request, PredictReply *reply) override {
|
||||||
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
|
#ifdef ENABLE_D
|
||||||
|
if (g_ctx == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "rtCtx is nullptr";
|
||||||
|
return grpc::Status::CANCELLED;
|
||||||
|
}
|
||||||
|
rtError_t rt_ret = rtCtxSetCurrent(g_ctx);
|
||||||
|
if (rt_ret != RT_ERROR_NONE) {
|
||||||
|
MS_LOG(ERROR) << "set Ascend rtCtx failed";
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
std::vector<MSTensorPtr> inputs;
|
||||||
|
inference::MultiTensor outputs;
|
||||||
|
for (int i = 0; i < request->data_size(); i++) {
|
||||||
|
auto input = ServingTensor2MSTensor(request->data(i));
|
||||||
|
if (input == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Tensor convert failed";
|
||||||
|
return grpc::Status::CANCELLED;
|
||||||
|
}
|
||||||
|
inputs.push_back(input);
|
||||||
|
}
|
||||||
|
auto res = Session::Instance().Predict(inputs, &outputs);
|
||||||
|
if (res != SUCCESS) {
|
||||||
|
return grpc::Status::CANCELLED;
|
||||||
|
}
|
||||||
|
for (const auto &tensor : outputs) {
|
||||||
|
*reply->add_result() = MSTensor2ServingTensor(tensor);
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Finish call service Eval";
|
||||||
|
return grpc::Status::OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
grpc::Status Test(grpc::ServerContext *context, const PredictRequest *request, PredictReply *reply) override {
|
||||||
|
MS_LOG(INFO) << "TestService call";
|
||||||
|
return grpc::Status::OK;
|
||||||
|
}
|
||||||
|
std::mutex mutex_;
|
||||||
|
};
|
||||||
|
|
||||||
|
Status Server::BuildAndStart() {
|
||||||
|
// handle exit signal
|
||||||
|
signal(SIGINT, HandleSignal);
|
||||||
|
Status res;
|
||||||
|
auto option_args = Options::Instance().GetArgs();
|
||||||
|
std::string server_address = "0.0.0.0:" + std::to_string(option_args->grpc_port);
|
||||||
|
std::string model_path = option_args->model_path;
|
||||||
|
std::string model_name = option_args->model_name;
|
||||||
|
std::string device_type = option_args->device_type;
|
||||||
|
auto device_id = option_args->device_id;
|
||||||
|
res = Session::Instance().CreatDeviceSession(device_type, device_id);
|
||||||
|
if (res != SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << "creat session failed";
|
||||||
|
ClearEnv();
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
VersionController version_controller(option_args->poll_model_wait_seconds, model_path, model_name);
|
||||||
|
res = version_controller.Run();
|
||||||
|
if (res != SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << "load model failed";
|
||||||
|
ClearEnv();
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
#ifdef ENABLE_D
|
||||||
|
// set d context
|
||||||
|
rtContext_t ctx = nullptr;
|
||||||
|
rtError_t rt_ret = rtCtxGetCurrent(&ctx);
|
||||||
|
if (rt_ret != RT_ERROR_NONE || ctx == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "the ascend device context is null";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
g_ctx = ctx;
|
||||||
|
#endif
|
||||||
|
MSServiceImpl service;
|
||||||
|
grpc::EnableDefaultHealthCheckService(true);
|
||||||
|
grpc::reflection::InitProtoReflectionServerBuilderPlugin();
|
||||||
|
// Set the port is not reuseable
|
||||||
|
auto option = grpc::MakeChannelArgumentOption(GRPC_ARG_ALLOW_REUSEPORT, 0);
|
||||||
|
grpc::ServerBuilder builder;
|
||||||
|
builder.SetOption(std::move(option));
|
||||||
|
// Listen on the given address without any authentication mechanism.
|
||||||
|
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
||||||
|
// Register "service" as the instance through which we'll communicate with
|
||||||
|
// clients. In this case it corresponds to an *synchronous* service.
|
||||||
|
builder.RegisterService(&service);
|
||||||
|
// Finally assemble the server.
|
||||||
|
std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
|
||||||
|
MS_LOG(INFO) << "Server listening on " << server_address << std::endl;
|
||||||
|
|
||||||
|
// Wait for the server to shutdown. Note that some other thread must be
|
||||||
|
// responsible for shutting down the server for this call to ever return.
|
||||||
|
server->Wait();
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace serving
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,56 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_SERVER_H
|
||||||
|
#define MINDSPORE_SERVER_H
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <mutex>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include "util/status.h"
|
||||||
|
#include "version_control/model.h"
|
||||||
|
#include "include/inference.h"
|
||||||
|
#include "mindspore/ccsrc/debug/info.h"
|
||||||
|
namespace mindspore {
|
||||||
|
namespace serving {
|
||||||
|
class Session {
|
||||||
|
public:
|
||||||
|
static Session &Instance();
|
||||||
|
Status CreatDeviceSession(const std::string &device, uint32_t device_id);
|
||||||
|
Status Predict(const std::vector<std::shared_ptr<inference::MSTensor>> &inputs, inference::MultiTensor *output);
|
||||||
|
Status Warmup(const MindSporeModelPtr model);
|
||||||
|
Status Clear();
|
||||||
|
|
||||||
|
private:
|
||||||
|
Session() = default;
|
||||||
|
~Session() = default;
|
||||||
|
int sesseion_id_{0};
|
||||||
|
std::shared_ptr<inference::MSSession> session_{nullptr};
|
||||||
|
FuncGraphPtr last_graph_{nullptr};
|
||||||
|
uint32_t graph_id_{0};
|
||||||
|
std::mutex mutex_;
|
||||||
|
std::string device_type_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class Server {
|
||||||
|
public:
|
||||||
|
Server() = default;
|
||||||
|
~Server() = default;
|
||||||
|
Status BuildAndStart();
|
||||||
|
};
|
||||||
|
} // namespace serving
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_SERVER_H
|
|
@ -0,0 +1,102 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "core/util/file_system_operation.h"
|
||||||
|
#include <unistd.h>
|
||||||
|
#include <dirent.h>
|
||||||
|
#include <sys/types.h>
|
||||||
|
#include <sys/stat.h>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <iostream>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <ctime>
|
||||||
|
#include <fstream>
|
||||||
|
#include <memory>
|
||||||
|
#include "mindspore/ccsrc/utils/log_adapter.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace serving {
|
||||||
|
|
||||||
|
char *ReadFile(const char *file, size_t *size) {
|
||||||
|
if (file == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "file is nullptr";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
MS_ASSERT(size != nullptr);
|
||||||
|
std::string realPath = file;
|
||||||
|
std::ifstream ifs(realPath);
|
||||||
|
if (!ifs.good()) {
|
||||||
|
MS_LOG(ERROR) << "file: " << realPath << " is not exist";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!ifs.is_open()) {
|
||||||
|
MS_LOG(ERROR) << "file: " << realPath << "open failed";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
ifs.seekg(0, std::ios::end);
|
||||||
|
*size = ifs.tellg();
|
||||||
|
std::unique_ptr<char> buf(new (std::nothrow) char[*size]);
|
||||||
|
if (buf == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "malloc buf failed, file: " << realPath;
|
||||||
|
ifs.close();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
ifs.seekg(0, std::ios::beg);
|
||||||
|
ifs.read(buf.get(), *size);
|
||||||
|
ifs.close();
|
||||||
|
|
||||||
|
return buf.release();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool DirOrFileExist(const std::string &file_path) {
|
||||||
|
int ret = access(file_path.c_str(), 0);
|
||||||
|
return (ret == -1) ? false : true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> GetAllSubDirs(const std::string &dir_path) {
|
||||||
|
DIR *dir;
|
||||||
|
struct dirent *ptr;
|
||||||
|
std::vector<std::string> SubDirs;
|
||||||
|
|
||||||
|
if ((dir = opendir(dir_path.c_str())) == NULL) {
|
||||||
|
MS_LOG(ERROR) << "Open " << dir_path << " error!";
|
||||||
|
return std::vector<std::string>();
|
||||||
|
}
|
||||||
|
|
||||||
|
while ((ptr = readdir(dir)) != NULL) {
|
||||||
|
std::string name = ptr->d_name;
|
||||||
|
if (name == "." || name == "..") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (ptr->d_type == DT_DIR) {
|
||||||
|
SubDirs.push_back(dir_path + "/" + name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
closedir(dir);
|
||||||
|
std::sort(SubDirs.begin(), SubDirs.end());
|
||||||
|
return SubDirs;
|
||||||
|
}
|
||||||
|
|
||||||
|
time_t GetModifyTime(const std::string &file_path) {
|
||||||
|
struct stat info;
|
||||||
|
(void)stat(file_path.c_str(), &info);
|
||||||
|
return info.st_mtime;
|
||||||
|
}
|
||||||
|
} // namespace serving
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,32 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_SERVING_FILE_SYSTEM_OPERATION_H_
|
||||||
|
#define MINDSPORE_SERVING_FILE_SYSTEM_OPERATION_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <ctime>
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace serving {
|
||||||
|
char *ReadFile(const char *file, size_t *size);
|
||||||
|
bool DirOrFileExist(const std::string &file_path);
|
||||||
|
std::vector<std::string> GetAllSubDirs(const std::string &dir_path);
|
||||||
|
time_t GetModifyTime(const std::string &file_path);
|
||||||
|
} // namespace serving
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // !MINDSPORE_SERVING_FILE_SYSTEM_OPERATION_H_
|
|
@ -0,0 +1,243 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "core/util/option_parser.h"
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <cstring>
|
||||||
|
#include <iostream>
|
||||||
|
#include <iomanip>
|
||||||
|
#include "mindspore/ccsrc/utils/log_adapter.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace serving {
|
||||||
|
bool StartWith(const std::string &str, const std::string &expected) {
|
||||||
|
return expected.empty() ||
|
||||||
|
(str.size() >= expected.size() && memcmp(str.data(), expected.data(), expected.size()) == 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool RemovePrefix(std::string *str, const std::string &prefix) {
|
||||||
|
if (!StartWith(*str, prefix)) return false;
|
||||||
|
str->replace(str->begin(), str->begin() + prefix.size(), "");
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Option::ParseInt32(std::string *arg) {
|
||||||
|
if (RemovePrefix(arg, "--") && RemovePrefix(arg, name_) && RemovePrefix(arg, "=")) {
|
||||||
|
char extra;
|
||||||
|
int32_t parsed_value;
|
||||||
|
if (sscanf(arg->data(), "%d%c", &parsed_value, &extra) != 1) {
|
||||||
|
std::cout << "Parse " << name_ << "Error for option " << *arg << std::endl;
|
||||||
|
return false;
|
||||||
|
} else {
|
||||||
|
*int32_default_ = parsed_value;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Option::ParseBool(std::string *arg) {
|
||||||
|
if (RemovePrefix(arg, "--") && RemovePrefix(arg, name_) && RemovePrefix(arg, "=")) {
|
||||||
|
if (*arg == "true") {
|
||||||
|
*bool_default_ = true;
|
||||||
|
} else if (*arg == "false") {
|
||||||
|
*bool_default_ = false;
|
||||||
|
} else {
|
||||||
|
std::cout << "Parse " << name_ << " Error for option " << *arg << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Option::ParseString(std::string *arg) {
|
||||||
|
if (RemovePrefix(arg, "--") && RemovePrefix(arg, name_) && RemovePrefix(arg, "=")) {
|
||||||
|
*string_default_ = *arg;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Option::ParseFloat(std::string *arg) {
|
||||||
|
if (RemovePrefix(arg, "--") && RemovePrefix(arg, name_) && RemovePrefix(arg, "=")) {
|
||||||
|
char extra;
|
||||||
|
float parsed_value;
|
||||||
|
if (sscanf(arg->data(), "%f%c", &parsed_value, &extra) != 1) {
|
||||||
|
std::cout << "Parse " << name_ << "Error for option " << *arg << std::endl;
|
||||||
|
return false;
|
||||||
|
} else {
|
||||||
|
*float_default_ = parsed_value;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
Option::Option(const std::string &name, int32_t *default_point, const std::string &usage)
|
||||||
|
: name_(name),
|
||||||
|
type_(MS_TYPE_INT32),
|
||||||
|
int32_default_(default_point),
|
||||||
|
bool_default_(nullptr),
|
||||||
|
string_default_(nullptr),
|
||||||
|
float_default_(nullptr),
|
||||||
|
usage_(usage) {}
|
||||||
|
|
||||||
|
Option::Option(const std::string &name, bool *default_point, const std::string &usage)
|
||||||
|
: name_(name),
|
||||||
|
type_(MS_TYPE_BOOL),
|
||||||
|
int32_default_(nullptr),
|
||||||
|
bool_default_(default_point),
|
||||||
|
string_default_(nullptr),
|
||||||
|
float_default_(nullptr),
|
||||||
|
usage_(usage) {}
|
||||||
|
|
||||||
|
Option::Option(const std::string &name, std::string *default_point, const std::string &usage)
|
||||||
|
: name_(name),
|
||||||
|
type_(MS_TYPE_STRING),
|
||||||
|
int32_default_(nullptr),
|
||||||
|
bool_default_(nullptr),
|
||||||
|
string_default_(default_point),
|
||||||
|
float_default_(nullptr),
|
||||||
|
usage_(usage) {}
|
||||||
|
|
||||||
|
Option::Option(const std::string &name, float *default_point, const std::string &usage)
|
||||||
|
: name_(name),
|
||||||
|
type_(MS_TYPE_FLOAT),
|
||||||
|
int32_default_(nullptr),
|
||||||
|
bool_default_(nullptr),
|
||||||
|
string_default_(nullptr),
|
||||||
|
float_default_(default_point),
|
||||||
|
usage_(usage) {}
|
||||||
|
|
||||||
|
bool Option::Parse(std::string *arg) {
|
||||||
|
bool result = false;
|
||||||
|
switch (type_) {
|
||||||
|
case MS_TYPE_BOOL:
|
||||||
|
result = ParseBool(arg);
|
||||||
|
break;
|
||||||
|
case MS_TYPE_FLOAT:
|
||||||
|
result = ParseFloat(arg);
|
||||||
|
break;
|
||||||
|
case MS_TYPE_INT32:
|
||||||
|
result = ParseInt32(arg);
|
||||||
|
break;
|
||||||
|
case MS_TYPE_STRING:
|
||||||
|
result = ParseString(arg);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<Options> Options::inst_ = nullptr;
|
||||||
|
|
||||||
|
Options &Options::Instance() {
|
||||||
|
static Options instance;
|
||||||
|
return instance;
|
||||||
|
}
|
||||||
|
|
||||||
|
Options::Options() : args_(nullptr) { CreateOptions(); }
|
||||||
|
|
||||||
|
void Options::CreateOptions() {
|
||||||
|
args_ = std::make_shared<Arguments>();
|
||||||
|
std::vector<Option> options = {
|
||||||
|
Option("port", &args_->grpc_port, "Port to listen on for gRPC API, default is 5500"),
|
||||||
|
Option("model_name", &args_->model_name, "model name "),
|
||||||
|
Option("model_path", &args_->model_path, "the path of the model files"),
|
||||||
|
Option("device_id", &args_->device_id, "the device id, default is 0"),
|
||||||
|
};
|
||||||
|
options_ = options;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Options::CheckOptions() {
|
||||||
|
if (args_->model_name == "" || args_->model_path == "") {
|
||||||
|
std::cout << "model_path and model_name should not be null" << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (args_->device_type != "Ascend") {
|
||||||
|
std::cout << "device_type only support Ascend right now" << std::endl;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Options::ParseCommandLine(int argc, char **argv) {
|
||||||
|
if (argc < 2 || (strcmp(argv[1], "--help") == 0)) {
|
||||||
|
Usage();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
std::vector<std::string> unkown_options;
|
||||||
|
for (int i = 1; i < argc; ++i) {
|
||||||
|
bool found = false;
|
||||||
|
for (auto &option : options_) {
|
||||||
|
std::string arg = argv[i];
|
||||||
|
if (option.Parse(&arg)) {
|
||||||
|
found = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (found == false) {
|
||||||
|
unkown_options.push_back(argv[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!unkown_options.empty()) {
|
||||||
|
std::cout << "unkown options:" << std::endl;
|
||||||
|
for (const auto &option : unkown_options) {
|
||||||
|
std::cout << option << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
bool valid = (unkown_options.empty() && CheckOptions());
|
||||||
|
if (!valid) {
|
||||||
|
Usage();
|
||||||
|
}
|
||||||
|
return valid;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Options::Usage() {
|
||||||
|
std::cout << "USAGE: mindspore-serving [options]" << std::endl;
|
||||||
|
|
||||||
|
for (const auto &option : options_) {
|
||||||
|
std::string type;
|
||||||
|
switch (option.type_) {
|
||||||
|
case Option::MS_TYPE_BOOL:
|
||||||
|
type = "bool";
|
||||||
|
break;
|
||||||
|
case Option::MS_TYPE_FLOAT:
|
||||||
|
type = "float";
|
||||||
|
break;
|
||||||
|
case Option::MS_TYPE_INT32:
|
||||||
|
type = "int32";
|
||||||
|
break;
|
||||||
|
case Option::MS_TYPE_STRING:
|
||||||
|
type = "string";
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
std::cout << "--" << std::setw(30) << std::left << option.name_ << std::setw(10) << std::left << type
|
||||||
|
<< option.usage_ << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace serving
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,84 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_SERVING_OPTION_PARSER_H_
|
||||||
|
#define MINDSPORE_SERVING_OPTION_PARSER_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace serving {
|
||||||
|
|
||||||
|
struct Arguments {
|
||||||
|
int32_t grpc_port = 5500;
|
||||||
|
std::string grpc_socket_path;
|
||||||
|
std::string ssl_config_file;
|
||||||
|
int32_t poll_model_wait_seconds = 1;
|
||||||
|
std::string model_name;
|
||||||
|
std::string model_path;
|
||||||
|
std::string device_type = "Ascend";
|
||||||
|
int32_t device_id = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
class Option {
|
||||||
|
public:
|
||||||
|
Option(const std::string &name, int32_t *default_point, const std::string &usage);
|
||||||
|
Option(const std::string &name, bool *default_point, const std::string &usage);
|
||||||
|
Option(const std::string &name, std::string *default_point, const std::string &usage);
|
||||||
|
Option(const std::string &name, float *default_point, const std::string &usage);
|
||||||
|
|
||||||
|
private:
|
||||||
|
friend class Options;
|
||||||
|
|
||||||
|
bool ParseInt32(std::string *arg);
|
||||||
|
bool ParseBool(std::string *arg);
|
||||||
|
bool ParseString(std::string *arg);
|
||||||
|
bool ParseFloat(std::string *arg);
|
||||||
|
bool Parse(std::string *arg);
|
||||||
|
std::string name_;
|
||||||
|
enum { MS_TYPE_INT32, MS_TYPE_BOOL, MS_TYPE_STRING, MS_TYPE_FLOAT } type_;
|
||||||
|
int32_t *int32_default_;
|
||||||
|
bool *bool_default_;
|
||||||
|
std::string *string_default_;
|
||||||
|
float *float_default_;
|
||||||
|
std::string usage_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class Options {
|
||||||
|
public:
|
||||||
|
~Options() = default;
|
||||||
|
Options(const Options &) = delete;
|
||||||
|
Options &operator=(const Options &) = delete;
|
||||||
|
static Options &Instance();
|
||||||
|
bool ParseCommandLine(int argc, char **argv);
|
||||||
|
void Usage();
|
||||||
|
std::shared_ptr<Arguments> GetArgs() { return args_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
Options();
|
||||||
|
void CreateOptions();
|
||||||
|
bool CheckOptions();
|
||||||
|
static std::shared_ptr<Options> inst_;
|
||||||
|
std::string usage_;
|
||||||
|
std::vector<Option> options_;
|
||||||
|
std::shared_ptr<Arguments> args_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace serving
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif
|
|
@ -0,0 +1,25 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_STATUS_H
|
||||||
|
#define MINDSPORE_STATUS_H
|
||||||
|
namespace mindspore {
|
||||||
|
namespace serving {
|
||||||
|
using Status = uint32_t;
|
||||||
|
enum ServingStatus { SUCCESS = 0, FAILED };
|
||||||
|
} // namespace serving
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_STATUS_H
|
|
@ -0,0 +1,33 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "core/version_control/model.h"
|
||||||
|
#include <string>
|
||||||
|
#include "mindspore/ccsrc/utils/log_adapter.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace serving {
|
||||||
|
|
||||||
|
MindSporeModel::MindSporeModel(const std::string &model_name, const std::string &model_path,
|
||||||
|
const std::string &model_version, const time_t &last_update_time)
|
||||||
|
: model_name_(model_name),
|
||||||
|
model_path_(model_path),
|
||||||
|
model_version_(model_version),
|
||||||
|
last_update_time_(last_update_time) {
|
||||||
|
MS_LOG(INFO) << "init mindspore model, model_name = " << model_name_ << ", model_path = " << model_path_
|
||||||
|
<< ", model_version = " << model_version_ << ", last_update_time = " << last_update_time_;
|
||||||
|
}
|
||||||
|
} // namespace serving
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,47 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_SERVING_MODEL_H_
|
||||||
|
#define MINDSPORE_SERVING_MODEL_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <ctime>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace serving {
|
||||||
|
class MindSporeModel {
|
||||||
|
public:
|
||||||
|
MindSporeModel(const std::string &model_name, const std::string &model_path, const std::string &model_version,
|
||||||
|
const time_t &last_update_time);
|
||||||
|
~MindSporeModel() = default;
|
||||||
|
std::string GetModelName() { return model_name_; }
|
||||||
|
std::string GetModelPath() { return model_path_; }
|
||||||
|
std::string GetModelVersion() { return model_version_; }
|
||||||
|
time_t GetLastUpdateTime() { return last_update_time_; }
|
||||||
|
void SetLastUpdateTime(const time_t &last_update_time) { last_update_time_ = last_update_time; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string model_name_;
|
||||||
|
std::string model_path_;
|
||||||
|
std::string model_version_;
|
||||||
|
time_t last_update_time_;
|
||||||
|
};
|
||||||
|
|
||||||
|
using MindSporeModelPtr = std::shared_ptr<MindSporeModel>;
|
||||||
|
} // namespace serving
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // !MINDSPORE_SERVING_MODEL_H_
|
|
@ -0,0 +1,134 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "core/version_control/version_controller.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <iostream>
|
||||||
|
#include <ctime>
|
||||||
|
#include <memory>
|
||||||
|
#include "util/file_system_operation.h"
|
||||||
|
#include "mindspore/ccsrc/utils/log_adapter.h"
|
||||||
|
#include "core/server.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace serving {
|
||||||
|
|
||||||
|
volatile bool stop_poll = false;
|
||||||
|
|
||||||
|
std::string GetVersionFromPath(const std::string &path) {
|
||||||
|
std::string new_path = path;
|
||||||
|
if (path.back() == '/') {
|
||||||
|
new_path = path.substr(0, path.size() - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string::size_type index = new_path.find_last_of("/");
|
||||||
|
std::string version = new_path.substr(index + 1);
|
||||||
|
return version;
|
||||||
|
}
|
||||||
|
|
||||||
|
void PeriodicFunction::operator()() {
|
||||||
|
while (true) {
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(poll_model_wait_seconds_ * 1000));
|
||||||
|
std::vector<std::string> SubDirs = GetAllSubDirs(models_path_);
|
||||||
|
|
||||||
|
if (version_control_strategy_ == VersionController::VersionControllerStrategy::kLastest) {
|
||||||
|
auto path = SubDirs.empty() ? models_path_ : SubDirs.back();
|
||||||
|
std::string model_version = GetVersionFromPath(path);
|
||||||
|
time_t last_update_time = GetModifyTime(path);
|
||||||
|
if (model_version != valid_models_.back()->GetModelVersion()) {
|
||||||
|
MindSporeModelPtr model_ptr = std::make_shared<MindSporeModel>(valid_models_.front()->GetModelName(), path,
|
||||||
|
model_version, last_update_time);
|
||||||
|
valid_models_.back() = model_ptr;
|
||||||
|
Session::Instance().Warmup(valid_models_.back());
|
||||||
|
} else {
|
||||||
|
if (difftime(valid_models_.back()->GetLastUpdateTime(), last_update_time) < 0) {
|
||||||
|
valid_models_.back()->SetLastUpdateTime(last_update_time);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// not support
|
||||||
|
}
|
||||||
|
|
||||||
|
if (stop_poll == true) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
VersionController::VersionController(int32_t poll_model_wait_seconds, const std::string &models_path,
|
||||||
|
const std::string &model_name)
|
||||||
|
: version_control_strategy_(kLastest),
|
||||||
|
poll_model_wait_seconds_(poll_model_wait_seconds),
|
||||||
|
models_path_(models_path),
|
||||||
|
model_name_(model_name) {}
|
||||||
|
|
||||||
|
void StopPollModelPeriodic() { stop_poll = true; }
|
||||||
|
|
||||||
|
VersionController::~VersionController() {
|
||||||
|
StopPollModelPeriodic();
|
||||||
|
if (poll_model_thread_.joinable()) {
|
||||||
|
poll_model_thread_.join();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Status VersionController::Run() {
|
||||||
|
Status ret;
|
||||||
|
ret = CreateInitModels();
|
||||||
|
if (ret != SUCCESS) {
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
// disable periodic check
|
||||||
|
// StartPollModelPeriodic();
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status VersionController::CreateInitModels() {
|
||||||
|
if (!DirOrFileExist(models_path_)) {
|
||||||
|
MS_LOG(ERROR) << "Model Path Not Exist!" << std::endl;
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
std::vector<std::string> SubDirs = GetAllSubDirs(models_path_);
|
||||||
|
if (version_control_strategy_ == kLastest) {
|
||||||
|
auto path = SubDirs.empty() ? models_path_ : SubDirs.back();
|
||||||
|
std::string model_version = GetVersionFromPath(path);
|
||||||
|
time_t last_update_time = GetModifyTime(path);
|
||||||
|
MindSporeModelPtr model_ptr = std::make_shared<MindSporeModel>(model_name_, path, model_version, last_update_time);
|
||||||
|
valid_models_.emplace_back(model_ptr);
|
||||||
|
} else {
|
||||||
|
for (auto &dir : SubDirs) {
|
||||||
|
std::string model_version = GetVersionFromPath(dir);
|
||||||
|
time_t last_update_time = GetModifyTime(dir);
|
||||||
|
MindSporeModelPtr model_ptr = std::make_shared<MindSporeModel>(model_name_, dir, model_version, last_update_time);
|
||||||
|
valid_models_.emplace_back(model_ptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (valid_models_.empty()) {
|
||||||
|
MS_LOG(ERROR) << "There is no valid model for serving";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
Session::Instance().Warmup(valid_models_.back());
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
void VersionController::StartPollModelPeriodic() {
|
||||||
|
poll_model_thread_ = std::thread(
|
||||||
|
PeriodicFunction(poll_model_wait_seconds_, models_path_, version_control_strategy_, std::ref(valid_models_)));
|
||||||
|
}
|
||||||
|
|
||||||
|
void VersionController::StopPollModelPeriodic() {}
|
||||||
|
|
||||||
|
} // namespace serving
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,71 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_SERVING_VERSOIN_CONTROLLER_H_
|
||||||
|
#define MINDSPORE_SERVING_VERSOIN_CONTROLLER_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <thread>
|
||||||
|
#include "./model.h"
|
||||||
|
#include "util/status.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace serving {
|
||||||
|
class VersionController {
|
||||||
|
public:
|
||||||
|
enum VersionControllerStrategy { kLastest = 0, kMulti = 1 };
|
||||||
|
|
||||||
|
VersionController(int32_t poll_model_wait_seconds, const std::string &models_path, const std::string &model_name);
|
||||||
|
~VersionController();
|
||||||
|
Status Run();
|
||||||
|
void StartPollModelPeriodic();
|
||||||
|
void StopPollModelPeriodic();
|
||||||
|
|
||||||
|
private:
|
||||||
|
Status CreateInitModels();
|
||||||
|
|
||||||
|
private:
|
||||||
|
VersionControllerStrategy version_control_strategy_;
|
||||||
|
std::vector<MindSporeModelPtr> valid_models_;
|
||||||
|
int32_t poll_model_wait_seconds_;
|
||||||
|
std::thread poll_model_thread_;
|
||||||
|
std::string models_path_;
|
||||||
|
std::string model_name_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class PeriodicFunction {
|
||||||
|
public:
|
||||||
|
PeriodicFunction(int32_t poll_model_wait_seconds, const std::string &models_path,
|
||||||
|
VersionController::VersionControllerStrategy version_control_strategy,
|
||||||
|
const std::vector<MindSporeModelPtr> &valid_models)
|
||||||
|
: poll_model_wait_seconds_(poll_model_wait_seconds),
|
||||||
|
models_path_(models_path),
|
||||||
|
version_control_strategy_(version_control_strategy),
|
||||||
|
valid_models_(valid_models) {}
|
||||||
|
~PeriodicFunction() = default;
|
||||||
|
void operator()();
|
||||||
|
|
||||||
|
private:
|
||||||
|
int32_t poll_model_wait_seconds_;
|
||||||
|
std::string models_path_;
|
||||||
|
VersionController::VersionControllerStrategy version_control_strategy_;
|
||||||
|
std::vector<MindSporeModelPtr> valid_models_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace serving
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // !MINDSPORE_SERVING_VERSOIN_CONTROLLER_H_
|
|
@ -0,0 +1,72 @@
|
||||||
|
cmake_minimum_required(VERSION 3.5.1)
|
||||||
|
|
||||||
|
project(HelloWorld C CXX)
|
||||||
|
|
||||||
|
|
||||||
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
|
||||||
|
|
||||||
|
|
||||||
|
find_package(Threads REQUIRED)
|
||||||
|
|
||||||
|
# This branch assumes that gRPC and all its dependencies are already installed
|
||||||
|
# on this system, so they can be located by find_package().
|
||||||
|
|
||||||
|
# Find Protobuf installation
|
||||||
|
# Looks for protobuf-config.cmake file installed by Protobuf's cmake installation.
|
||||||
|
set(protobuf_MODULE_COMPATIBLE TRUE)
|
||||||
|
find_package(Protobuf CONFIG REQUIRED)
|
||||||
|
message(STATUS "Using protobuf ${protobuf_VERSION}")
|
||||||
|
|
||||||
|
set(_PROTOBUF_LIBPROTOBUF protobuf::libprotobuf)
|
||||||
|
set(_REFLECTION gRPC::grpc++_reflection)
|
||||||
|
if(CMAKE_CROSSCOMPILING)
|
||||||
|
find_program(_PROTOBUF_PROTOC protoc)
|
||||||
|
else()
|
||||||
|
set(_PROTOBUF_PROTOC $<TARGET_FILE:protobuf::protoc>)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# Find gRPC installation
|
||||||
|
# Looks for gRPCConfig.cmake file installed by gRPC's cmake installation.
|
||||||
|
find_package(gRPC CONFIG REQUIRED)
|
||||||
|
message(STATUS "Using gRPC ${gRPC_VERSION}")
|
||||||
|
|
||||||
|
set(_GRPC_GRPCPP gRPC::grpc++)
|
||||||
|
if(CMAKE_CROSSCOMPILING)
|
||||||
|
find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin)
|
||||||
|
else()
|
||||||
|
set(_GRPC_CPP_PLUGIN_EXECUTABLE $<TARGET_FILE:gRPC::grpc_cpp_plugin>)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# Proto file
|
||||||
|
get_filename_component(hw_proto "../ms_service.proto" ABSOLUTE)
|
||||||
|
get_filename_component(hw_proto_path "${hw_proto}" PATH)
|
||||||
|
|
||||||
|
# Generated sources
|
||||||
|
set(hw_proto_srcs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.pb.cc")
|
||||||
|
set(hw_proto_hdrs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.pb.h")
|
||||||
|
set(hw_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.grpc.pb.cc")
|
||||||
|
set(hw_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.grpc.pb.h")
|
||||||
|
add_custom_command(
|
||||||
|
OUTPUT "${hw_proto_srcs}" "${hw_proto_hdrs}" "${hw_grpc_srcs}" "${hw_grpc_hdrs}"
|
||||||
|
COMMAND ${_PROTOBUF_PROTOC}
|
||||||
|
ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}"
|
||||||
|
--cpp_out "${CMAKE_CURRENT_BINARY_DIR}"
|
||||||
|
-I "${hw_proto_path}"
|
||||||
|
--plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}"
|
||||||
|
"${hw_proto}"
|
||||||
|
DEPENDS "${hw_proto}")
|
||||||
|
|
||||||
|
# Include generated *.pb.h files
|
||||||
|
include_directories("${CMAKE_CURRENT_BINARY_DIR}")
|
||||||
|
|
||||||
|
# Targets greeter_[async_](client|server)
|
||||||
|
foreach(_target
|
||||||
|
ms_client ms_server)
|
||||||
|
add_executable(${_target} "${_target}.cc"
|
||||||
|
${hw_proto_srcs}
|
||||||
|
${hw_grpc_srcs})
|
||||||
|
target_link_libraries(${_target}
|
||||||
|
${_REFLECTION}
|
||||||
|
${_GRPC_GRPCPP}
|
||||||
|
${_PROTOBUF_LIBPROTOBUF})
|
||||||
|
endforeach()
|
|
@ -0,0 +1,105 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include <grpcpp/grpcpp.h>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include "serving/ms_service.grpc.pb.h"
|
||||||
|
|
||||||
|
using grpc::Channel;
|
||||||
|
using grpc::ClientContext;
|
||||||
|
using grpc::Status;
|
||||||
|
using ms_serving::MSService;
|
||||||
|
using ms_serving::PredictReply;
|
||||||
|
using ms_serving::PredictRequest;
|
||||||
|
using ms_serving::Tensor;
|
||||||
|
using ms_serving::TensorShape;
|
||||||
|
|
||||||
|
class MSClient {
|
||||||
|
public:
|
||||||
|
explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {}
|
||||||
|
|
||||||
|
std::string Predict(const std::string &user) {
|
||||||
|
// Data we are sending to the server.
|
||||||
|
PredictRequest request;
|
||||||
|
Tensor data;
|
||||||
|
TensorShape shape;
|
||||||
|
shape.add_dims(1);
|
||||||
|
shape.add_dims(1);
|
||||||
|
shape.add_dims(2);
|
||||||
|
shape.add_dims(2);
|
||||||
|
*data.mutable_tensor_shape() = shape;
|
||||||
|
data.set_tensor_type(ms_serving::MS_FLOAT32);
|
||||||
|
vector<float> input_data{1.1, 2.1, 3.1, 4.1};
|
||||||
|
data.set_data(input_data.data(), input_data.size());
|
||||||
|
*request.add_data() = data;
|
||||||
|
*request.add_data() = data;
|
||||||
|
|
||||||
|
// Container for the data we expect from the server.
|
||||||
|
PredictReply reply;
|
||||||
|
|
||||||
|
// Context for the client. It could be used to convey extra information to
|
||||||
|
// the server and/or tweak certain RPC behaviors.
|
||||||
|
ClientContext context;
|
||||||
|
|
||||||
|
// The actual RPC.
|
||||||
|
Status status = stub_->Predict(&context, request, &reply);
|
||||||
|
|
||||||
|
// Act upon its status.
|
||||||
|
if (status.ok()) {
|
||||||
|
return "RPC OK";
|
||||||
|
} else {
|
||||||
|
std::cout << status.error_code() << ": " << status.error_message() << std::endl;
|
||||||
|
return "RPC failed";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unique_ptr<MSService::Stub> stub_;
|
||||||
|
};
|
||||||
|
|
||||||
|
int main(int argc, char **argv) {
|
||||||
|
// Instantiate the client. It requires a channel, out of which the actual RPCs
|
||||||
|
// are created. This channel models a connection to an endpoint specified by
|
||||||
|
// the argument "--target=" which is the only expected argument.
|
||||||
|
// We indicate that the channel isn't authenticated (use of
|
||||||
|
// InsecureChannelCredentials()).
|
||||||
|
std::string target_str;
|
||||||
|
std::string arg_str("--target");
|
||||||
|
if (argc > 1) {
|
||||||
|
std::string arg_val = argv[1];
|
||||||
|
size_t start_pos = arg_val.find(arg_str);
|
||||||
|
if (start_pos != std::string::npos) {
|
||||||
|
start_pos += arg_str.size();
|
||||||
|
if (arg_val[start_pos] == '=') {
|
||||||
|
target_str = arg_val.substr(start_pos + 1);
|
||||||
|
} else {
|
||||||
|
std::cout << "The only correct argument syntax is --target=" << std::endl;
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
std::cout << "The only acceptable argument is --target=" << std::endl;
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
target_str = "localhost:85010";
|
||||||
|
}
|
||||||
|
MSClient client(grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials()));
|
||||||
|
string request;
|
||||||
|
string reply = client.Predict(request);
|
||||||
|
std::cout << "client received: " << reply << std::endl;
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
|
@ -0,0 +1,67 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include <grpcpp/grpcpp.h>
|
||||||
|
#include <grpcpp/health_check_service_interface.h>
|
||||||
|
#include <grpcpp/ext/proto_server_reflection_plugin.h>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include "serving/ms_service.grpc.pb.h"
|
||||||
|
|
||||||
|
using grpc::Server;
|
||||||
|
using grpc::ServerBuilder;
|
||||||
|
using grpc::ServerContext;
|
||||||
|
using grpc::Status;
|
||||||
|
using ms_serving::MSService;
|
||||||
|
using ms_serving::PredictReply;
|
||||||
|
using ms_serving::PredictRequest;
|
||||||
|
|
||||||
|
// Logic and data behind the server's behavior.
|
||||||
|
class MSServiceImpl final : public MSService::Service {
|
||||||
|
Status Predict(ServerContext *context, const PredictRequest *request, PredictReply *reply) override {
|
||||||
|
cout << "server eval" << endl;
|
||||||
|
return Status::OK;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void RunServer() {
|
||||||
|
std::string server_address("0.0.0.0:50051");
|
||||||
|
MSServiceImpl service;
|
||||||
|
|
||||||
|
grpc::EnableDefaultHealthCheckService(true);
|
||||||
|
grpc::reflection::InitProtoReflectionServerBuilderPlugin();
|
||||||
|
auto option = grpc::MakeChannelArgumentOption(GRPC_ARG_ALLOW_REUSEPORT, 0);
|
||||||
|
|
||||||
|
ServerBuilder builder;
|
||||||
|
builder.SetOption(std::move(option));
|
||||||
|
// Listen on the given address without any authentication mechanism.
|
||||||
|
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
||||||
|
// Register "service" as the instance through which we'll communicate with
|
||||||
|
// clients. In this case it corresponds to an *synchronous* service.
|
||||||
|
builder.RegisterService(&service);
|
||||||
|
// Finally assemble the server.
|
||||||
|
std::unique_ptr<Server> server(builder.BuildAndStart());
|
||||||
|
std::cout << "Server listening on " << server_address << std::endl;
|
||||||
|
|
||||||
|
// Wait for the server to shutdown. Note that some other thread must be
|
||||||
|
// responsible for shutting down the server for this call to ever return.
|
||||||
|
server->Wait();
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(int argc, char **argv) {
|
||||||
|
RunServer();
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
|
@ -0,0 +1,29 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "core/server.h"
|
||||||
|
#include "core/util/option_parser.h"
|
||||||
|
|
||||||
|
using mindspore::serving::Options;
|
||||||
|
|
||||||
|
int main(int argc, char **argv) {
|
||||||
|
auto flag = Options::Instance().ParseCommandLine(argc, argv);
|
||||||
|
if (!flag) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
mindspore::serving::Server server;
|
||||||
|
server.BuildAndStart();
|
||||||
|
return 0;
|
||||||
|
}
|
|
@ -0,0 +1,48 @@
|
||||||
|
// ms_service.proto
|
||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package ms_serving;
|
||||||
|
|
||||||
|
service MSService {
|
||||||
|
rpc Predict(PredictRequest) returns (PredictReply) {}
|
||||||
|
rpc Test(PredictRequest) returns (PredictReply) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
message PredictRequest {
|
||||||
|
repeated Tensor data = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message PredictReply {
|
||||||
|
repeated Tensor result = 1;
|
||||||
|
}
|
||||||
|
enum DataType {
|
||||||
|
MS_UNKNOWN = 0;
|
||||||
|
MS_BOOL = 1;
|
||||||
|
MS_INT8 = 2;
|
||||||
|
MS_UINT8 = 3;
|
||||||
|
MS_INT16 = 4;
|
||||||
|
MS_UINT16 = 5;
|
||||||
|
MS_INT32 = 6;
|
||||||
|
MS_UINT32 = 7;
|
||||||
|
MS_INT64 = 8;
|
||||||
|
MS_UINT64 = 9;
|
||||||
|
MS_FLOAT16 = 10;
|
||||||
|
MS_FLOAT32 = 11;
|
||||||
|
MS_FLOAT64 = 12;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TensorShape {
|
||||||
|
repeated int64 dims = 1;
|
||||||
|
};
|
||||||
|
|
||||||
|
message Tensor {
|
||||||
|
// tensor shape info
|
||||||
|
TensorShape tensor_shape = 1;
|
||||||
|
|
||||||
|
// tensor content data type
|
||||||
|
DataType tensor_type = 2;
|
||||||
|
|
||||||
|
// tensor data
|
||||||
|
bytes data = 3;
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import grpc
|
||||||
|
import numpy as np
|
||||||
|
import ms_service_pb2
|
||||||
|
import ms_service_pb2_grpc
|
||||||
|
|
||||||
|
|
||||||
|
def run():
|
||||||
|
channel = grpc.insecure_channel('localhost:50051')
|
||||||
|
stub = ms_service_pb2_grpc.MSServiceStub(channel)
|
||||||
|
# request = ms_service_pb2.PredictRequest()
|
||||||
|
# request.name = 'haha'
|
||||||
|
# response = stub.Eval(request)
|
||||||
|
# print("ms client received: " + response.message)
|
||||||
|
|
||||||
|
request = ms_service_pb2.PredictRequest()
|
||||||
|
request.data.tensor_shape.dims.extend([32, 1, 32, 32])
|
||||||
|
request.data.tensor_type = ms_service_pb2.MS_FLOAT32
|
||||||
|
request.data.data = (np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01).tobytes()
|
||||||
|
|
||||||
|
request.label.tensor_shape.dims.extend([32])
|
||||||
|
request.label.tensor_type = ms_service_pb2.MS_INT32
|
||||||
|
request.label.data = np.ones([32]).astype(np.int32).tobytes()
|
||||||
|
|
||||||
|
result = stub.Predict(request)
|
||||||
|
#result_np = np.frombuffer(result.result.data, dtype=np.float32).reshape(result.result.tensor_shape.dims)
|
||||||
|
print("ms client received: ")
|
||||||
|
#print(result_np)
|
||||||
|
|
||||||
|
# future_list = []
|
||||||
|
# times = 1000
|
||||||
|
# for i in range(times):
|
||||||
|
# async_future = stub.Eval.future(request)
|
||||||
|
# future_list.append(async_future)
|
||||||
|
# print("async call, future list add item " + str(i));
|
||||||
|
#
|
||||||
|
# for i in range(len(future_list)):
|
||||||
|
# async_result = future_list[i].result()
|
||||||
|
# print("ms client async get result of item " + str(i))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
run()
|
|
@ -0,0 +1,46 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import grpc
|
||||||
|
import numpy as np
|
||||||
|
import ms_service_pb2
|
||||||
|
import ms_service_pb2_grpc
|
||||||
|
|
||||||
|
|
||||||
|
def run():
|
||||||
|
channel = grpc.insecure_channel('localhost:50051')
|
||||||
|
stub = ms_service_pb2_grpc.MSServiceStub(channel)
|
||||||
|
# request = ms_service_pb2.EvalRequest()
|
||||||
|
# request.name = 'haha'
|
||||||
|
# response = stub.Eval(request)
|
||||||
|
# print("ms client received: " + response.message)
|
||||||
|
|
||||||
|
request = ms_service_pb2.PredictRequest()
|
||||||
|
request.data.tensor_shape.dims.extend([32, 1, 32, 32])
|
||||||
|
request.data.tensor_type = ms_service_pb2.MS_FLOAT32
|
||||||
|
request.data.data = (np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01).tobytes()
|
||||||
|
|
||||||
|
request.label.tensor_shape.dims.extend([32])
|
||||||
|
request.label.tensor_type = ms_service_pb2.MS_INT32
|
||||||
|
request.label.data = np.ones([32]).astype(np.int32).tobytes()
|
||||||
|
|
||||||
|
result = stub.Test(request)
|
||||||
|
#result_np = np.frombuffer(result.result.data, dtype=np.float32).reshape(result.result.tensor_shape.dims)
|
||||||
|
print("ms client test call received: ")
|
||||||
|
#print(result_np)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
run()
|
|
@ -0,0 +1,55 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
from concurrent import futures
|
||||||
|
import time
|
||||||
|
import grpc
|
||||||
|
import numpy as np
|
||||||
|
import ms_service_pb2
|
||||||
|
import ms_service_pb2_grpc
|
||||||
|
import test_cpu_lenet
|
||||||
|
from mindspore import Tensor
|
||||||
|
|
||||||
|
class MSService(ms_service_pb2_grpc.MSServiceServicer):
|
||||||
|
def Predict(self, request, context):
|
||||||
|
request_data = request.data
|
||||||
|
request_label = request.label
|
||||||
|
|
||||||
|
data_from_buffer = np.frombuffer(request_data.data, dtype=np.float32)
|
||||||
|
data_from_buffer = data_from_buffer.reshape(request_data.tensor_shape.dims)
|
||||||
|
data = Tensor(data_from_buffer)
|
||||||
|
|
||||||
|
label_from_buffer = np.frombuffer(request_label.data, dtype=np.int32)
|
||||||
|
label_from_buffer = label_from_buffer.reshape(request_label.tensor_shape.dims)
|
||||||
|
label = Tensor(label_from_buffer)
|
||||||
|
|
||||||
|
result = test_cpu_lenet.test_lenet(data, label)
|
||||||
|
result_reply = ms_service_pb2.PredictReply()
|
||||||
|
result_reply.result.tensor_shape.dims.extend(result.shape())
|
||||||
|
result_reply.result.data = result.asnumpy().tobytes()
|
||||||
|
return result_reply
|
||||||
|
|
||||||
|
def serve():
|
||||||
|
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
|
||||||
|
ms_service_pb2_grpc.add_MSServiceServicer_to_server(MSService(), server)
|
||||||
|
server.add_insecure_port('[::]:50051')
|
||||||
|
server.start()
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
time.sleep(60*60*24) # one day in seconds
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
server.stop(0)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
serve()
|
|
@ -0,0 +1,96 @@
|
||||||
|
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||||
|
import grpc
|
||||||
|
|
||||||
|
import ms_service_pb2 as ms__service__pb2
|
||||||
|
|
||||||
|
|
||||||
|
class MSServiceStub(object):
|
||||||
|
"""Missing associated documentation comment in .proto file"""
|
||||||
|
|
||||||
|
def __init__(self, channel):
|
||||||
|
"""Constructor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel: A grpc.Channel.
|
||||||
|
"""
|
||||||
|
self.Predict = channel.unary_unary(
|
||||||
|
'/ms_serving.MSService/Predict',
|
||||||
|
request_serializer=ms__service__pb2.PredictRequest.SerializeToString,
|
||||||
|
response_deserializer=ms__service__pb2.PredictReply.FromString,
|
||||||
|
)
|
||||||
|
self.Test = channel.unary_unary(
|
||||||
|
'/ms_serving.MSService/Test',
|
||||||
|
request_serializer=ms__service__pb2.PredictRequest.SerializeToString,
|
||||||
|
response_deserializer=ms__service__pb2.PredictReply.FromString,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MSServiceServicer(object):
|
||||||
|
"""Missing associated documentation comment in .proto file"""
|
||||||
|
|
||||||
|
def Predict(self, request, context):
|
||||||
|
"""Missing associated documentation comment in .proto file"""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
def Test(self, request, context):
|
||||||
|
"""Missing associated documentation comment in .proto file"""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
|
||||||
|
def add_MSServiceServicer_to_server(servicer, server):
|
||||||
|
rpc_method_handlers = {
|
||||||
|
'Predict': grpc.unary_unary_rpc_method_handler(
|
||||||
|
servicer.Predict,
|
||||||
|
request_deserializer=ms__service__pb2.PredictRequest.FromString,
|
||||||
|
response_serializer=ms__service__pb2.PredictReply.SerializeToString,
|
||||||
|
),
|
||||||
|
'Test': grpc.unary_unary_rpc_method_handler(
|
||||||
|
servicer.Test,
|
||||||
|
request_deserializer=ms__service__pb2.PredictRequest.FromString,
|
||||||
|
response_serializer=ms__service__pb2.PredictReply.SerializeToString,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
generic_handler = grpc.method_handlers_generic_handler(
|
||||||
|
'ms_serving.MSService', rpc_method_handlers)
|
||||||
|
server.add_generic_rpc_handlers((generic_handler,))
|
||||||
|
|
||||||
|
|
||||||
|
# This class is part of an EXPERIMENTAL API.
|
||||||
|
class MSService(object):
|
||||||
|
"""Missing associated documentation comment in .proto file"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def Predict(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(request, target, '/ms_serving.MSService/Predict',
|
||||||
|
ms__service__pb2.PredictRequest.SerializeToString,
|
||||||
|
ms__service__pb2.PredictReply.FromString,
|
||||||
|
options, channel_credentials,
|
||||||
|
call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def Test(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(request, target, '/ms_serving.MSService/Test',
|
||||||
|
ms__service__pb2.PredictRequest.SerializeToString,
|
||||||
|
ms__service__pb2.PredictReply.FromString,
|
||||||
|
options, channel_credentials,
|
||||||
|
call_credentials, compression, wait_for_ready, timeout, metadata)
|
|
@ -0,0 +1,91 @@
|
||||||
|
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import numpy as np
|
||||||
|
import mindspore.context as context
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.nn import TrainOneStepCell, WithLossCell
|
||||||
|
from mindspore.nn.optim import Momentum
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
import ms_service_pb2
|
||||||
|
|
||||||
|
|
||||||
|
class LeNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(LeNet, self).__init__()
|
||||||
|
self.relu = P.ReLU()
|
||||||
|
self.batch_size = 32
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
|
||||||
|
self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
|
||||||
|
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||||
|
self.reshape = P.Reshape()
|
||||||
|
self.fc1 = nn.Dense(400, 120)
|
||||||
|
self.fc2 = nn.Dense(120, 84)
|
||||||
|
self.fc3 = nn.Dense(84, 10)
|
||||||
|
|
||||||
|
def construct(self, input_x):
|
||||||
|
output = self.conv1(input_x)
|
||||||
|
output = self.relu(output)
|
||||||
|
output = self.pool(output)
|
||||||
|
output = self.conv2(output)
|
||||||
|
output = self.relu(output)
|
||||||
|
output = self.pool(output)
|
||||||
|
output = self.reshape(output, (self.batch_size, -1))
|
||||||
|
output = self.fc1(output)
|
||||||
|
output = self.relu(output)
|
||||||
|
output = self.fc2(output)
|
||||||
|
output = self.relu(output)
|
||||||
|
output = self.fc3(output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def train(net, data, label):
|
||||||
|
learning_rate = 0.01
|
||||||
|
momentum = 0.9
|
||||||
|
|
||||||
|
optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
|
||||||
|
criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
|
||||||
|
net_with_criterion = WithLossCell(net, criterion)
|
||||||
|
train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer
|
||||||
|
train_network.set_train()
|
||||||
|
res = train_network(data, label)
|
||||||
|
print("+++++++++Loss+++++++++++++")
|
||||||
|
print(res)
|
||||||
|
print("+++++++++++++++++++++++++++")
|
||||||
|
assert res
|
||||||
|
return res
|
||||||
|
|
||||||
|
def test_lenet(data, label):
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||||
|
net = LeNet()
|
||||||
|
return train(net, data, label)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tensor = ms_service_pb2.Tensor()
|
||||||
|
tensor.tensor_shape.dim.extend([32, 1, 32, 32])
|
||||||
|
# tensor.tensor_shape.dim.add() = 1
|
||||||
|
# tensor.tensor_shape.dim.add() = 32
|
||||||
|
# tensor.tensor_shape.dim.add() = 32
|
||||||
|
tensor.tensor_type = ms_service_pb2.MS_FLOAT32
|
||||||
|
tensor.data = np.ones([32, 1, 32, 32]).astype(np.float32).tobytes()
|
||||||
|
|
||||||
|
data_from_buffer = np.frombuffer(tensor.data, dtype=np.float32)
|
||||||
|
print(tensor.tensor_shape.dim)
|
||||||
|
data_from_buffer = data_from_buffer.reshape(tensor.tensor_shape.dim)
|
||||||
|
print(data_from_buffer.shape)
|
||||||
|
input_data = Tensor(data_from_buffer * 0.01)
|
||||||
|
input_label = Tensor(np.ones([32]).astype(np.int32))
|
||||||
|
test_lenet(input_data, input_label)
|
|
@ -0,0 +1,105 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
CLANG_FORMAT=$(which clang-format) || (echo "Please install 'clang-format' tool first"; exit 1)
|
||||||
|
|
||||||
|
version=$("${CLANG_FORMAT}" --version | sed -n "s/.*\ \([0-9]*\)\.[0-9]*\.[0-9]*.*/\1/p")
|
||||||
|
if [[ "${version}" -lt "8" ]]; then
|
||||||
|
echo "clang-format's version must be at least 8.0.0"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
CURRENT_PATH=$(pwd)
|
||||||
|
SCRIPTS_PATH=$(dirname "$0")
|
||||||
|
|
||||||
|
echo "CURRENT_PATH=${CURRENT_PATH}"
|
||||||
|
echo "SCRIPTS_PATH=${SCRIPTS_PATH}"
|
||||||
|
|
||||||
|
# print usage message
|
||||||
|
function usage()
|
||||||
|
{
|
||||||
|
echo "Format the specified source files to conform the code style."
|
||||||
|
echo "Usage:"
|
||||||
|
echo "bash $0 [-a] [-c] [-l] [-h]"
|
||||||
|
echo "e.g. $0 -c"
|
||||||
|
echo ""
|
||||||
|
echo "Options:"
|
||||||
|
echo " -a format of all files"
|
||||||
|
echo " -c format of the files changed compared to last commit, default case"
|
||||||
|
echo " -l format of the files changed in last commit"
|
||||||
|
echo " -h Print usage"
|
||||||
|
}
|
||||||
|
|
||||||
|
# check and set options
|
||||||
|
function checkopts()
|
||||||
|
{
|
||||||
|
# init variable
|
||||||
|
mode="changed" # default format changed files
|
||||||
|
|
||||||
|
# Process the options
|
||||||
|
while getopts 'aclh' opt
|
||||||
|
do
|
||||||
|
case "${opt}" in
|
||||||
|
a)
|
||||||
|
mode="all"
|
||||||
|
;;
|
||||||
|
c)
|
||||||
|
mode="changed"
|
||||||
|
;;
|
||||||
|
l)
|
||||||
|
mode="lastcommit"
|
||||||
|
;;
|
||||||
|
h)
|
||||||
|
usage
|
||||||
|
exit 0
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "Unknown option ${opt}!"
|
||||||
|
usage
|
||||||
|
exit 1
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
}
|
||||||
|
|
||||||
|
# init variable
|
||||||
|
# check options
|
||||||
|
checkopts "$@"
|
||||||
|
|
||||||
|
# switch to project root path, which contains clang-format config file '.clang-format'
|
||||||
|
cd "${SCRIPTS_PATH}/.." || exit 1
|
||||||
|
|
||||||
|
FMT_FILE_LIST='__format_files_list__'
|
||||||
|
|
||||||
|
if [[ "X${mode}" == "Xall" ]]; then
|
||||||
|
find ./ -type f -name "*" | grep "\.h$\|\.cc$" > "${FMT_FILE_LIST}" || true
|
||||||
|
elif [[ "X${mode}" == "Xchanged" ]]; then
|
||||||
|
git diff --name-only | grep "\.h$\|\.cc$" > "${FMT_FILE_LIST}" || true
|
||||||
|
else # "X${mode}" == "Xlastcommit"
|
||||||
|
git diff --name-only HEAD~ HEAD | grep "\.h$\|\.cc$" > "${FMT_FILE_LIST}" || true
|
||||||
|
fi
|
||||||
|
|
||||||
|
while read line; do
|
||||||
|
if [ -f "${line}" ]; then
|
||||||
|
${CLANG_FORMAT} -i "${line}"
|
||||||
|
fi
|
||||||
|
done < "${FMT_FILE_LIST}"
|
||||||
|
|
||||||
|
rm "${FMT_FILE_LIST}"
|
||||||
|
cd "${CURRENT_PATH}" || exit 1
|
||||||
|
|
||||||
|
echo "Specified cpp source files have been format successfully."
|
Loading…
Reference in New Issue