forked from mindspore-Ecosystem/mindspore
!9765 remove serving in mindspore repo
From: @xu-yfei Reviewed-by: @zhoufeng54,@kisnwang Signed-off-by: @kisnwang
This commit is contained in:
commit
058fbd2d1f
|
@ -81,9 +81,4 @@ if (ENABLE_TESTCASES)
|
||||||
add_subdirectory(tests)
|
add_subdirectory(tests)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (ENABLE_SERVING)
|
|
||||||
add_subdirectory(serving)
|
|
||||||
add_subdirectory(serving/example/cpp_client)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
include(cmake/package.cmake)
|
include(cmake/package.cmake)
|
||||||
|
|
19
build.sh
19
build.sh
|
@ -54,7 +54,6 @@ usage()
|
||||||
echo " -V Specify the device version, if -e gpu, default CUDA 10.1, if -e ascend, default Ascend 910"
|
echo " -V Specify the device version, if -e gpu, default CUDA 10.1, if -e ascend, default Ascend 910"
|
||||||
echo " -I Enable compiling mindspore lite for arm64, arm32 or x86_64, default disable mindspore lite compilation"
|
echo " -I Enable compiling mindspore lite for arm64, arm32 or x86_64, default disable mindspore lite compilation"
|
||||||
echo " -K Compile with AKG, default on"
|
echo " -K Compile with AKG, default on"
|
||||||
echo " -s Enable serving module, default off"
|
|
||||||
echo " -B Enable debugger, default on"
|
echo " -B Enable debugger, default on"
|
||||||
echo " -E Enable IBVERBS for parameter server, default off"
|
echo " -E Enable IBVERBS for parameter server, default off"
|
||||||
echo " -l Compile with python dependency, default on"
|
echo " -l Compile with python dependency, default on"
|
||||||
|
@ -105,7 +104,6 @@ checkopts()
|
||||||
SUPPORT_TRAIN="off"
|
SUPPORT_TRAIN="off"
|
||||||
USE_GLOG="on"
|
USE_GLOG="on"
|
||||||
ENABLE_AKG="on"
|
ENABLE_AKG="on"
|
||||||
ENABLE_SERVING="off"
|
|
||||||
ENABLE_ACL="off"
|
ENABLE_ACL="off"
|
||||||
ENABLE_DEBUGGER="on"
|
ENABLE_DEBUGGER="on"
|
||||||
ENABLE_IBVERBS="off"
|
ENABLE_IBVERBS="off"
|
||||||
|
@ -123,7 +121,7 @@ checkopts()
|
||||||
DEVICE=""
|
DEVICE=""
|
||||||
ENABLE_NPU="off"
|
ENABLE_NPU="off"
|
||||||
# Process the options
|
# Process the options
|
||||||
while getopts 'drvj:c:t:hsb:a:g:p:ie:m:l:I:LRP:D:zM:V:K:swB:En:T:A:C:o:S:k:W:' opt
|
while getopts 'drvj:c:t:hsb:a:g:p:ie:m:l:I:LRP:D:zM:V:K:B:En:T:A:C:o:S:k:W:' opt
|
||||||
do
|
do
|
||||||
OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]')
|
OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]')
|
||||||
case "${opt}" in
|
case "${opt}" in
|
||||||
|
@ -273,16 +271,6 @@ checkopts()
|
||||||
ENABLE_AKG="on"
|
ENABLE_AKG="on"
|
||||||
echo "enable compile with akg"
|
echo "enable compile with akg"
|
||||||
;;
|
;;
|
||||||
s)
|
|
||||||
ENABLE_SERVING="on"
|
|
||||||
echo "enable serving"
|
|
||||||
;;
|
|
||||||
w)
|
|
||||||
ENABLE_SERVING="on"
|
|
||||||
echo "enable serving"
|
|
||||||
ENABLE_ACL="on"
|
|
||||||
echo "enable acl"
|
|
||||||
;;
|
|
||||||
B)
|
B)
|
||||||
check_on_off $OPTARG B
|
check_on_off $OPTARG B
|
||||||
ENABLE_DEBUGGER="$OPTARG"
|
ENABLE_DEBUGGER="$OPTARG"
|
||||||
|
@ -366,12 +354,10 @@ checkopts()
|
||||||
DEVICE_VERSION=910
|
DEVICE_VERSION=910
|
||||||
fi
|
fi
|
||||||
if [[ "X$DEVICE_VERSION" == "X310" ]]; then
|
if [[ "X$DEVICE_VERSION" == "X310" ]]; then
|
||||||
ENABLE_SERVING="on"
|
|
||||||
ENABLE_ACL="on"
|
ENABLE_ACL="on"
|
||||||
elif [[ "X$DEVICE_VERSION" == "X910" ]]; then
|
elif [[ "X$DEVICE_VERSION" == "X910" ]]; then
|
||||||
ENABLE_D="on"
|
ENABLE_D="on"
|
||||||
ENABLE_CPU="on"
|
ENABLE_CPU="on"
|
||||||
ENABLE_SERVING="on"
|
|
||||||
else
|
else
|
||||||
echo "Invalid value ${DEVICE_VERSION} for option -V"
|
echo "Invalid value ${DEVICE_VERSION} for option -V"
|
||||||
usage
|
usage
|
||||||
|
@ -467,9 +453,6 @@ build_mindspore()
|
||||||
if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" || "X$ENABLE_GPU" = "Xon" ]]; then
|
if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" || "X$ENABLE_GPU" = "Xon" ]]; then
|
||||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_AKG=ON"
|
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_AKG=ON"
|
||||||
fi
|
fi
|
||||||
if [[ "X$ENABLE_SERVING" = "Xon" ]]; then
|
|
||||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_SERVING=ON"
|
|
||||||
fi
|
|
||||||
if [[ "X$ENABLE_ACL" = "Xon" ]]; then
|
if [[ "X$ENABLE_ACL" = "Xon" ]]; then
|
||||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_ACL=ON"
|
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_ACL=ON"
|
||||||
fi
|
fi
|
||||||
|
|
|
@ -88,7 +88,7 @@ if (ENABLE_MINDDATA)
|
||||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/sentencepiece.cmake)
|
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/sentencepiece.cmake)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (ENABLE_MINDDATA OR ENABLE_SERVING)
|
if (ENABLE_MINDDATA)
|
||||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/jpeg_turbo.cmake)
|
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/jpeg_turbo.cmake)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
|
@ -119,7 +119,7 @@ if(ENABLE_DEBUGGER)
|
||||||
add_compile_definitions(ENABLE_DEBUGGER)
|
add_compile_definitions(ENABLE_DEBUGGER)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (ENABLE_DEBUGGER OR ENABLE_SERVING OR ENABLE_TESTCASES)
|
if (ENABLE_DEBUGGER OR ENABLE_TESTCASES)
|
||||||
set(MS_BUILD_GRPC ON)
|
set(MS_BUILD_GRPC ON)
|
||||||
endif()
|
endif()
|
||||||
if (ENABLE_MINDDATA AND NOT CMAKE_SYSTEM_NAME MATCHES "Windows")
|
if (ENABLE_MINDDATA AND NOT CMAKE_SYSTEM_NAME MATCHES "Windows")
|
||||||
|
|
|
@ -202,7 +202,7 @@ if (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
|
||||||
)
|
)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (ENABLE_SERVING OR ENABLE_TESTCASES)
|
if (ENABLE_TESTCASES)
|
||||||
file(GLOB_RECURSE LIBEVENT_LIB_LIST
|
file(GLOB_RECURSE LIBEVENT_LIB_LIST
|
||||||
${libevent_LIBPATH}/libevent*
|
${libevent_LIBPATH}/libevent*
|
||||||
${libevent_LIBPATH}/libevent_pthreads*
|
${libevent_LIBPATH}/libevent_pthreads*
|
||||||
|
@ -336,29 +336,3 @@ install(
|
||||||
COMPONENT mindspore
|
COMPONENT mindspore
|
||||||
)
|
)
|
||||||
|
|
||||||
if (ENABLE_SERVING)
|
|
||||||
install(
|
|
||||||
TARGETS ms_serving
|
|
||||||
DESTINATION ${INSTALL_BASE_DIR}
|
|
||||||
COMPONENT mindspore
|
|
||||||
)
|
|
||||||
|
|
||||||
install(
|
|
||||||
FILES ${CMAKE_SOURCE_DIR}/build/mindspore/serving/ms_service_pb2.py
|
|
||||||
${CMAKE_SOURCE_DIR}/build/mindspore/serving/ms_service_pb2_grpc.py
|
|
||||||
DESTINATION ${INSTALL_PY_DIR}
|
|
||||||
COMPONENT mindspore
|
|
||||||
)
|
|
||||||
|
|
||||||
install(
|
|
||||||
TARGETS inference
|
|
||||||
DESTINATION ${INSTALL_LIB_DIR}
|
|
||||||
COMPONENT mindspore
|
|
||||||
)
|
|
||||||
|
|
||||||
install(
|
|
||||||
FILES ${LIBEVENT_LIB_LIST}
|
|
||||||
DESTINATION ${INSTALL_LIB_DIR}
|
|
||||||
COMPONENT mindspore
|
|
||||||
)
|
|
||||||
endif ()
|
|
||||||
|
|
|
@ -350,44 +350,9 @@ if (ENABLE_MINDDATA)
|
||||||
add_subdirectory(minddata/dataset)
|
add_subdirectory(minddata/dataset)
|
||||||
endif ()
|
endif ()
|
||||||
|
|
||||||
# build inference
|
|
||||||
set(LOAD_MINDIR_SRC
|
|
||||||
${CMAKE_SOURCE_DIR}/mindspore/core/load_mindir/load_model.cc
|
|
||||||
${CMAKE_SOURCE_DIR}/mindspore/core/load_mindir/anf_model_parser.cc
|
|
||||||
)
|
|
||||||
add_library(inference SHARED
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/backend/session/infer_session.cc
|
|
||||||
${LOAD_MINDIR_SRC}
|
|
||||||
)
|
|
||||||
|
|
||||||
set_target_properties(inference PROPERTIES INSTALL_RPATH ${MINDSPORE_RPATH})
|
|
||||||
|
|
||||||
if (CMAKE_SYSTEM_NAME MATCHES "Darwin")
|
|
||||||
target_link_libraries(inference PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
|
|
||||||
-Wl,-force_load mindspore proto_input -Wl,-noall_load mindspore_gvar)
|
|
||||||
else()
|
|
||||||
target_link_libraries(inference PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
|
|
||||||
-Wl,--whole-archive mindspore proto_input -Wl,--no-whole-archive mindspore_gvar)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (ENABLE_D)
|
if (ENABLE_D)
|
||||||
find_library(adump_server libadump_server.a ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
find_library(adump_server libadump_server.a ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||||
target_link_libraries(_c_expression PRIVATE ${adump_server})
|
target_link_libraries(_c_expression PRIVATE ${adump_server})
|
||||||
target_link_libraries(inference PRIVATE ${adump_server} ms_profile)
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (ENABLE_CPU)
|
|
||||||
target_link_libraries(inference PRIVATE mindspore::dnnl mindspore::mkldnn)
|
|
||||||
endif ()
|
|
||||||
|
|
||||||
if (USE_GLOG)
|
|
||||||
target_link_libraries(inference PRIVATE mindspore::glog)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
|
|
||||||
target_link_options(inference PRIVATE -Wl,-init,common_log_init)
|
|
||||||
elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin")
|
|
||||||
set_target_properties(inference PROPERTIES MACOSX_RPATH ON)
|
|
||||||
endif ()
|
|
||||||
|
|
||||||
add_subdirectory(cxx_api)
|
add_subdirectory(cxx_api)
|
||||||
|
|
|
@ -1,341 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "backend/session/infer_session.h"
|
|
||||||
#include <memory>
|
|
||||||
#include <algorithm>
|
|
||||||
#include <fstream>
|
|
||||||
|
|
||||||
#include "include/inference.h"
|
|
||||||
#include "backend/session/session_basic.h"
|
|
||||||
#include "backend/session/session_factory.h"
|
|
||||||
#include "backend/session/executor_manager.h"
|
|
||||||
#include "base/base_ref_utils.h"
|
|
||||||
#include "load_mindir/load_model.h"
|
|
||||||
#include "backend/kernel_compiler/oplib/oplib.h"
|
|
||||||
#include "utils/context/context_extends.h"
|
|
||||||
#include "runtime/device/kernel_runtime_manager.h"
|
|
||||||
|
|
||||||
#include "pybind11/pybind11.h"
|
|
||||||
|
|
||||||
#ifdef ENABLE_D
|
|
||||||
#include "utils/ms_context.h"
|
|
||||||
#endif
|
|
||||||
|
|
||||||
using std::string;
|
|
||||||
using std::vector;
|
|
||||||
|
|
||||||
namespace py = pybind11;
|
|
||||||
namespace mindspore {
|
|
||||||
namespace inference {
|
|
||||||
std::shared_ptr<InferSession> InferSession::CreateSession(const std::string &device, uint32_t device_id) {
|
|
||||||
try {
|
|
||||||
auto session = std::make_shared<MSInferSession>();
|
|
||||||
Status ret = session->InitEnv(device, device_id);
|
|
||||||
if (ret != SUCCESS) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
return session;
|
|
||||||
} catch (std::bad_alloc &e) {
|
|
||||||
MS_LOG(ERROR) << "Inference CreatSession failed, failed to alloc memory";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
MSInferSession::MSInferSession() = default;
|
|
||||||
MSInferSession::~MSInferSession() = default;
|
|
||||||
|
|
||||||
Status MSInferSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) {
|
|
||||||
Py_Initialize();
|
|
||||||
auto graph = mindspore::LoadMindIR(file_name);
|
|
||||||
if (graph == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
Status ret = CompileGraph(graph, model_id);
|
|
||||||
if (ret != SUCCESS) {
|
|
||||||
MS_LOG(ERROR) << "Compile graph model failed, file name is " << file_name.c_str();
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
MS_LOG(INFO) << "Load model from file " << file_name << " success";
|
|
||||||
|
|
||||||
#ifdef ENABLE_D
|
|
||||||
// set d context
|
|
||||||
rtError_t rt_ret = rtCtxGetCurrent(&context_);
|
|
||||||
if (rt_ret != RT_ERROR_NONE || context_ == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "the ascend device context is null";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status MSInferSession::UnloadModel(uint32_t model_id) { return SUCCESS; }
|
|
||||||
|
|
||||||
Status ServingTensor2MSTensor(size_t index, const InferTensorBase &out_tensor, tensor::TensorPtr &ms_tensor) {
|
|
||||||
std::vector<int64_t> shape = out_tensor.shape();
|
|
||||||
TypeId data_type;
|
|
||||||
const std::map<inference::DataType, TypeId> type2id_map{
|
|
||||||
{inference::kMSI_Unknown, TypeId::kNumberTypeBegin}, {inference::kMSI_Bool, TypeId::kNumberTypeBool},
|
|
||||||
{inference::kMSI_Int8, TypeId::kNumberTypeInt8}, {inference::kMSI_Uint8, TypeId::kNumberTypeUInt8},
|
|
||||||
{inference::kMSI_Int16, TypeId::kNumberTypeInt16}, {inference::kMSI_Uint16, TypeId::kNumberTypeUInt16},
|
|
||||||
{inference::kMSI_Int32, TypeId::kNumberTypeInt32}, {inference::kMSI_Uint32, TypeId::kNumberTypeUInt32},
|
|
||||||
{inference::kMSI_Int64, TypeId::kNumberTypeInt64}, {inference::kMSI_Uint64, TypeId::kNumberTypeUInt64},
|
|
||||||
{inference::kMSI_Float16, TypeId::kNumberTypeFloat16}, {inference::kMSI_Float32, TypeId::kNumberTypeFloat32},
|
|
||||||
{inference::kMSI_Float64, TypeId::kNumberTypeFloat64},
|
|
||||||
};
|
|
||||||
auto it = type2id_map.find(out_tensor.data_type());
|
|
||||||
if (it == type2id_map.end()) {
|
|
||||||
MSI_LOG_WARNING << "undefined MSI data type " << out_tensor.data_type();
|
|
||||||
return FAILED;
|
|
||||||
} else {
|
|
||||||
data_type = it->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
ms_tensor = std::make_shared<tensor::Tensor>(data_type, shape);
|
|
||||||
if (out_tensor.data_size() == 0 || ms_tensor->Size() != out_tensor.data_size()) {
|
|
||||||
MSI_LOG_ERROR << "input " << std::to_string(index)
|
|
||||||
<< " data size not match shape and dtype, calculated required size " << ms_tensor->Size()
|
|
||||||
<< ", given " << out_tensor.data_size();
|
|
||||||
return INFER_STATUS(INVALID_INPUTS) << "input " << std::to_string(index)
|
|
||||||
<< " data size not match shape and dtype, calculated required size "
|
|
||||||
<< ms_tensor->Size() << ", given " << out_tensor.data_size();
|
|
||||||
}
|
|
||||||
if (out_tensor.data() == nullptr || ms_tensor->data_c() == nullptr) {
|
|
||||||
MSI_LOG_ERROR << "invalid data buffer";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
auto ret_code = memcpy_s(ms_tensor->data_c(), ms_tensor->Size(), out_tensor.data(), out_tensor.data_size());
|
|
||||||
if (ret_code != 0) {
|
|
||||||
MS_LOG(ERROR) << "Failed to copy data from ms_tensor to out_tensor.";
|
|
||||||
}
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
void MSTensor2ServingTensor(tensor::TensorPtr ms_tensor, InferTensorBase &out_tensor) {
|
|
||||||
vector<int64_t> shape = ms_tensor->shape();
|
|
||||||
out_tensor.set_shape(shape);
|
|
||||||
|
|
||||||
const std::map<TypeId, inference::DataType> id2type_map{
|
|
||||||
{TypeId::kNumberTypeBegin, inference::kMSI_Unknown}, {TypeId::kNumberTypeBool, inference::kMSI_Bool},
|
|
||||||
{TypeId::kNumberTypeFloat64, inference::kMSI_Float64}, {TypeId::kNumberTypeInt8, inference::kMSI_Int8},
|
|
||||||
{TypeId::kNumberTypeUInt8, inference::kMSI_Uint8}, {TypeId::kNumberTypeInt16, inference::kMSI_Int16},
|
|
||||||
{TypeId::kNumberTypeUInt16, inference::kMSI_Uint16}, {TypeId::kNumberTypeInt32, inference::kMSI_Int32},
|
|
||||||
{TypeId::kNumberTypeUInt32, inference::kMSI_Uint32}, {TypeId::kNumberTypeInt64, inference::kMSI_Int64},
|
|
||||||
{TypeId::kNumberTypeUInt64, inference::kMSI_Uint64}, {TypeId::kNumberTypeFloat16, inference::kMSI_Float16},
|
|
||||||
{TypeId::kNumberTypeFloat32, inference::kMSI_Float32},
|
|
||||||
};
|
|
||||||
auto it = id2type_map.find(ms_tensor->data_type());
|
|
||||||
if (it == id2type_map.end()) {
|
|
||||||
MSI_LOG_WARNING << "undefined MS data type " << ms_tensor->data_type();
|
|
||||||
out_tensor.set_data_type(inference::kMSI_Unknown);
|
|
||||||
} else {
|
|
||||||
out_tensor.set_data_type(it->second);
|
|
||||||
}
|
|
||||||
out_tensor.set_data(ms_tensor->data_c(), ms_tensor->Size());
|
|
||||||
}
|
|
||||||
|
|
||||||
Status MSInferSession::ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) {
|
|
||||||
#ifdef ENABLE_D
|
|
||||||
if (context_ == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "rtCtx is nullptr";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
rtError_t rt_ret = rtCtxSetCurrent(context_);
|
|
||||||
if (rt_ret != RT_ERROR_NONE) {
|
|
||||||
MS_LOG(ERROR) << "set Ascend rtCtx failed";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
vector<tensor::TensorPtr> inputs;
|
|
||||||
for (size_t i = 0; i < request.size(); i++) {
|
|
||||||
if (request[i] == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Execute Model " << model_id << " Failed, input tensor is null, index " << i;
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
tensor::TensorPtr input = nullptr;
|
|
||||||
auto ret = ServingTensor2MSTensor(i, *request[i], input);
|
|
||||||
if (ret != SUCCESS) {
|
|
||||||
MS_LOG(ERROR) << "Tensor convert failed";
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
inputs.push_back(input);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto ret = CheckModelInputs(model_id, inputs);
|
|
||||||
if (ret != SUCCESS) {
|
|
||||||
MS_LOG(ERROR) << "Check Model " << model_id << " Inputs Failed";
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
vector<tensor::TensorPtr> outputs = RunGraph(model_id, inputs);
|
|
||||||
if (outputs.empty()) {
|
|
||||||
MS_LOG(ERROR) << "Execute Model " << model_id << " Failed";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
reply.clear();
|
|
||||||
for (const auto &tensor : outputs) {
|
|
||||||
auto out_tensor = reply.add();
|
|
||||||
if (out_tensor == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Execute Model " << model_id << " Failed add output tensor failed";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
MSTensor2ServingTensor(tensor, *out_tensor);
|
|
||||||
}
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status MSInferSession::FinalizeEnv() {
|
|
||||||
session::ExecutorManager::Instance().Clear();
|
|
||||||
device::KernelRuntimeManager::Instance().ClearRuntimeResource();
|
|
||||||
auto ms_context = MsContext::GetInstance();
|
|
||||||
if (ms_context == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Get Context failed!";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
if (!context::CloseTsd(ms_context)) {
|
|
||||||
MS_LOG(ERROR) << "Inference CloseTsd failed!";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
void MSInferSession::RegAllOp() {
|
|
||||||
static std::mutex init_mutex;
|
|
||||||
static bool Initialized = false;
|
|
||||||
|
|
||||||
std::lock_guard<std::mutex> lock(init_mutex);
|
|
||||||
if (Initialized) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
Initialized = true;
|
|
||||||
MsContext::GetInstance()->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
|
|
||||||
Py_Initialize();
|
|
||||||
auto c_expression = PyImport_ImportModule("mindspore._c_expression");
|
|
||||||
MS_EXCEPTION_IF_NULL(c_expression);
|
|
||||||
PyObject *c_expression_dict = PyModule_GetDict(c_expression);
|
|
||||||
MS_EXCEPTION_IF_NULL(c_expression_dict);
|
|
||||||
|
|
||||||
PyObject *op_info_loader_class = PyDict_GetItemString(c_expression_dict, "OpInfoLoaderPy");
|
|
||||||
MS_EXCEPTION_IF_NULL(op_info_loader_class);
|
|
||||||
PyObject *op_info_loader = PyInstanceMethod_New(op_info_loader_class);
|
|
||||||
MS_EXCEPTION_IF_NULL(op_info_loader);
|
|
||||||
PyObject *op_info_loader_ins = PyObject_CallObject(op_info_loader, nullptr);
|
|
||||||
MS_EXCEPTION_IF_NULL(op_info_loader_ins);
|
|
||||||
auto all_ops_info_vector_addr_ul = PyObject_CallMethod(op_info_loader_ins, "get_all_ops_info", nullptr);
|
|
||||||
MS_EXCEPTION_IF_NULL(all_ops_info_vector_addr_ul);
|
|
||||||
auto all_ops_info_vector_addr = PyLong_AsVoidPtr(all_ops_info_vector_addr_ul);
|
|
||||||
auto all_ops_info = static_cast<std::vector<kernel::OpInfo *> *>(all_ops_info_vector_addr);
|
|
||||||
for (auto op_info : *all_ops_info) {
|
|
||||||
kernel::OpLib::RegOpInfo(std::shared_ptr<kernel::OpInfo>(op_info));
|
|
||||||
}
|
|
||||||
all_ops_info->clear();
|
|
||||||
delete all_ops_info;
|
|
||||||
Py_DECREF(op_info_loader);
|
|
||||||
Py_DECREF(op_info_loader_class);
|
|
||||||
Py_DECREF(c_expression_dict);
|
|
||||||
Py_DECREF(c_expression);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status MSInferSession::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id) {
|
|
||||||
MS_ASSERT(session_impl_ != nullptr);
|
|
||||||
try {
|
|
||||||
auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
|
|
||||||
py::gil_scoped_release gil_release;
|
|
||||||
model_id = graph_id;
|
|
||||||
return SUCCESS;
|
|
||||||
} catch (std::exception &e) {
|
|
||||||
MS_LOG(ERROR) << "Inference CompileGraph failed";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<tensor::TensorPtr> MSInferSession::RunGraph(uint32_t graph_id,
|
|
||||||
const std::vector<tensor::TensorPtr> &inputs) {
|
|
||||||
try {
|
|
||||||
VectorRef outputs;
|
|
||||||
session_impl_->RunGraph(graph_id, inputs, &outputs);
|
|
||||||
return TransformVectorRefToMultiTensor(outputs);
|
|
||||||
} catch (std::exception &e) {
|
|
||||||
MS_LOG(ERROR) << "Inference Rungraph failed";
|
|
||||||
return std::vector<tensor::TensorPtr>();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
string MSInferSession::AjustTargetName(const std::string &device) {
|
|
||||||
if (device == kAscendDevice) {
|
|
||||||
return std::string(kAscendDevice) + "Inference";
|
|
||||||
} else {
|
|
||||||
MS_LOG(ERROR) << "Only support device Ascend right now";
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Status MSInferSession::InitEnv(const std::string &device, uint32_t device_id) {
|
|
||||||
RegAllOp();
|
|
||||||
auto ms_context = MsContext::GetInstance();
|
|
||||||
if (ms_context == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Get Context failed!";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
|
|
||||||
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id);
|
|
||||||
auto ajust_device = AjustTargetName(device);
|
|
||||||
if (ajust_device == "") {
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, device);
|
|
||||||
if (!context::OpenTsd(ms_context)) {
|
|
||||||
MS_LOG(ERROR) << "Session init OpenTsd failed!";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
session_impl_ = session::SessionFactory::Get().Create(ajust_device);
|
|
||||||
if (session_impl_ == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available.";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
session_impl_->Init(device_id);
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status MSInferSession::CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const {
|
|
||||||
MS_ASSERT(session_impl_ != nullptr);
|
|
||||||
std::string error_msg;
|
|
||||||
if (!session_impl_->CheckModelInputs(graph_id, inputs, &error_msg)) {
|
|
||||||
return INFER_STATUS(INVALID_INPUTS) << error_msg;
|
|
||||||
}
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status MSInferSession::GetModelInputsInfo(uint32_t model_id, std::vector<inference::InferTensor> *tensor_list) const {
|
|
||||||
vector<tensor::TensorPtr> inputs;
|
|
||||||
vector<std::string> input_names;
|
|
||||||
session_impl_->GetModelInputsInfo(model_id, &inputs, &input_names);
|
|
||||||
if (inputs.size() == 0) {
|
|
||||||
MS_LOG(ERROR) << "The model inputs is NULL";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
for (const auto &tensor : inputs) {
|
|
||||||
InferTensor infer_tensor = InferTensor();
|
|
||||||
MSTensor2ServingTensor(tensor, infer_tensor);
|
|
||||||
tensor_list->push_back(infer_tensor);
|
|
||||||
}
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
} // namespace inference
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,65 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#ifndef MINDSPORE_CCSRC_SESSION_SESSION_H
|
|
||||||
#define MINDSPORE_CCSRC_SESSION_SESSION_H
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
|
||||||
#include <unordered_map>
|
|
||||||
#include <utility>
|
|
||||||
#include <memory>
|
|
||||||
#include <map>
|
|
||||||
|
|
||||||
#include "backend/session/session_basic.h"
|
|
||||||
#include "ir/anf.h"
|
|
||||||
#include "include/inference.h"
|
|
||||||
|
|
||||||
#ifdef ENABLE_D
|
|
||||||
#include "runtime/context.h"
|
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace inference {
|
|
||||||
class MSInferSession : public InferSession {
|
|
||||||
public:
|
|
||||||
MSInferSession();
|
|
||||||
~MSInferSession();
|
|
||||||
|
|
||||||
Status InitEnv(const std::string &device_type, uint32_t device_id) override;
|
|
||||||
Status FinalizeEnv() override;
|
|
||||||
Status LoadModelFromFile(const std::string &file_name, uint32_t &model_id) override;
|
|
||||||
Status UnloadModel(uint32_t model_id) override;
|
|
||||||
Status ExecuteModel(uint32_t model_id, const RequestBase &inputs, ReplyBase &outputs) override;
|
|
||||||
Status GetModelInputsInfo(uint32_t graph_id, std::vector<inference::InferTensor> *tensor_list) const override;
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::shared_ptr<session::SessionBasic> session_impl_ = nullptr;
|
|
||||||
std::vector<uint32_t> graph_id_;
|
|
||||||
std::string device_type_;
|
|
||||||
int32_t device_id_ = 0;
|
|
||||||
#ifdef ENABLE_D
|
|
||||||
rtContext_t context_ = nullptr;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
static void RegAllOp();
|
|
||||||
string AjustTargetName(const std::string &device);
|
|
||||||
Status CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id);
|
|
||||||
Status CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const;
|
|
||||||
std::vector<tensor::TensorPtr> RunGraph(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs);
|
|
||||||
};
|
|
||||||
} // namespace inference
|
|
||||||
} // namespace mindspore
|
|
||||||
#endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H
|
|
|
@ -1,125 +0,0 @@
|
||||||
find_package(Threads REQUIRED)
|
|
||||||
|
|
||||||
# This branch assumes that gRPC and all its dependencies are already installed
|
|
||||||
# on this system, so they can be located by find_package().
|
|
||||||
|
|
||||||
# Find Protobuf installation
|
|
||||||
# Looks for protobuf-config.cmake file installed by Protobuf's cmake installation.
|
|
||||||
|
|
||||||
#set(protobuf_MODULE_COMPATIBLE TRUE)
|
|
||||||
#find_package(Protobuf CONFIG REQUIRED)
|
|
||||||
#message(STATUS "Using protobuf ${protobuf_VERSION}")
|
|
||||||
add_library(protobuf::libprotobuf ALIAS protobuf::protobuf)
|
|
||||||
add_executable(protobuf::libprotoc ALIAS protobuf::protoc)
|
|
||||||
|
|
||||||
set(_PROTOBUF_LIBPROTOBUF protobuf::libprotobuf)
|
|
||||||
if (CMAKE_CROSSCOMPILING)
|
|
||||||
find_program(_PROTOBUF_PROTOC protoc)
|
|
||||||
else ()
|
|
||||||
set(_PROTOBUF_PROTOC $<TARGET_FILE:protobuf::protoc>)
|
|
||||||
endif ()
|
|
||||||
|
|
||||||
# Find gRPC installation
|
|
||||||
# Looks for gRPCConfig.cmake file installed by gRPC's cmake installation.
|
|
||||||
if (EXISTS ${grpc_ROOT}/lib64)
|
|
||||||
set(gRPC_DIR "${grpc_ROOT}/lib64/cmake/grpc")
|
|
||||||
else ()
|
|
||||||
set(gRPC_DIR "${grpc_ROOT}/lib/cmake/grpc")
|
|
||||||
endif ()
|
|
||||||
message("serving using grpc_DIR : " ${gPRC_DIR})
|
|
||||||
|
|
||||||
find_package(gRPC CONFIG REQUIRED)
|
|
||||||
message(STATUS "Using gRPC ${gRPC_VERSION}")
|
|
||||||
|
|
||||||
set(_GRPC_GRPCPP gRPC::grpc++)
|
|
||||||
set(_REFLECTION gRPC::grpc++_reflection)
|
|
||||||
|
|
||||||
if (CMAKE_CROSSCOMPILING)
|
|
||||||
find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin)
|
|
||||||
find_program(_GRPC_PYTHON_PLUGIN_EXECUTABLE grpc_python_plugin)
|
|
||||||
else ()
|
|
||||||
set(_GRPC_CPP_PLUGIN_EXECUTABLE $<TARGET_FILE:gRPC::grpc_cpp_plugin>)
|
|
||||||
set(_GRPC_PYTHON_PLUGIN_EXECUTABLE $<TARGET_FILE:gRPC::grpc_python_plugin>)
|
|
||||||
endif ()
|
|
||||||
|
|
||||||
# Proto file
|
|
||||||
get_filename_component(hw_proto "ms_service.proto" ABSOLUTE)
|
|
||||||
get_filename_component(hw_proto_path "${hw_proto}" PATH)
|
|
||||||
# Generated sources
|
|
||||||
set(hw_proto_srcs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.pb.cc")
|
|
||||||
set(hw_proto_hdrs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.pb.h")
|
|
||||||
set(hw_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.grpc.pb.cc")
|
|
||||||
set(hw_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.grpc.pb.h")
|
|
||||||
set(hw_py_pb2 "${CMAKE_CURRENT_BINARY_DIR}/ms_service_pb2.py")
|
|
||||||
set(hw_py_pb2_grpc "${CMAKE_CURRENT_BINARY_DIR}/ms_service_pb2_grpc.py")
|
|
||||||
add_custom_command(
|
|
||||||
OUTPUT "${hw_proto_srcs}" "${hw_proto_hdrs}" "${hw_grpc_srcs}" "${hw_grpc_hdrs}" "${hw_py_pb2}" "${hw_py_pb2_grpc}"
|
|
||||||
COMMAND ${_PROTOBUF_PROTOC}
|
|
||||||
ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}"
|
|
||||||
--cpp_out "${CMAKE_CURRENT_BINARY_DIR}"
|
|
||||||
-I "${hw_proto_path}"
|
|
||||||
--plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}"
|
|
||||||
"${hw_proto}"
|
|
||||||
COMMAND ${_PROTOBUF_PROTOC}
|
|
||||||
ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}"
|
|
||||||
--python_out "${CMAKE_CURRENT_BINARY_DIR}"
|
|
||||||
-I "${hw_proto_path}"
|
|
||||||
--plugin=protoc-gen-grpc="${_GRPC_PYTHON_PLUGIN_EXECUTABLE}"
|
|
||||||
"${hw_proto}"
|
|
||||||
DEPENDS "${hw_proto}")
|
|
||||||
|
|
||||||
# Include generated *.pb.h files
|
|
||||||
include_directories("${CMAKE_CURRENT_BINARY_DIR}" "${CMAKE_CURRENT_SOURCE_DIR}" "${CMAKE_CURRENT_SOURCE_DIR}/core"
|
|
||||||
"${PROJECT_SOURCE_DIR}/mindspore/ccsrc" "${PROJECT_SOURCE_DIR}/mindspore/core")
|
|
||||||
file(GLOB_RECURSE CORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|
||||||
"core/*.cc" "core/util/*.cc" "core/version_control/*.cc")
|
|
||||||
|
|
||||||
list(APPEND SERVING_SRC "main.cc" ${hw_proto_srcs} ${hw_grpc_srcs} ${CORE_SRC_LIST})
|
|
||||||
|
|
||||||
if (ENABLE_ACL)
|
|
||||||
if (DEFINED ENV{ASCEND_CUSTOM_PATH})
|
|
||||||
set(ASCEND_PATH $ENV{ASCEND_CUSTOM_PATH})
|
|
||||||
else ()
|
|
||||||
set(ASCEND_PATH /usr/local/Ascend)
|
|
||||||
endif ()
|
|
||||||
set(ACL_LIB_DIR ${ASCEND_PATH}/acllib/)
|
|
||||||
set(ATLAS_ACL_LIB_DIR ${ASCEND_PATH}/ascend-toolkit/latest/acllib)
|
|
||||||
MESSAGE("hisi acl lib dir " ${ACL_LIB_DIR} " ,atlas acl lib dir " ${ATLAS_ACL_LIB_DIR})
|
|
||||||
|
|
||||||
include_directories(${ACL_LIB_DIR}/include/)
|
|
||||||
include_directories(${ATLAS_ACL_LIB_DIR}/include/)
|
|
||||||
file(GLOB_RECURSE ACL_SESSION_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "acl/*.cc")
|
|
||||||
list(APPEND SERVING_SRC ${ACL_SESSION_SRC_LIST})
|
|
||||||
endif ()
|
|
||||||
|
|
||||||
include_directories(${CMAKE_BINARY_DIR})
|
|
||||||
|
|
||||||
add_executable(ms_serving ${SERVING_SRC})
|
|
||||||
#libevent
|
|
||||||
target_link_libraries(ms_serving mindspore::event mindspore::event_pthreads)
|
|
||||||
|
|
||||||
target_link_libraries(ms_serving ${_REFLECTION} ${_GRPC_GRPCPP} ${_PROTOBUF_LIBPROTOBUF} pthread)
|
|
||||||
|
|
||||||
include(CheckPIESupported)
|
|
||||||
check_pie_supported()
|
|
||||||
set_property(TARGET ms_serving PROPERTY POSITION_INDEPENDENT_CODE TRUE)
|
|
||||||
|
|
||||||
if (ENABLE_D)
|
|
||||||
add_compile_definitions(ENABLE_D)
|
|
||||||
target_link_libraries(ms_serving ${RUNTIME_LIB})
|
|
||||||
endif ()
|
|
||||||
|
|
||||||
if (ENABLE_ACL)
|
|
||||||
add_compile_definitions(ENABLE_ACL)
|
|
||||||
add_compile_definitions(ENABLE_DVPP_INTERFACE)
|
|
||||||
find_library(acl libascendcl.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
|
|
||||||
find_library(acl_retr libacl_retr.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
|
|
||||||
find_library(acl_cblas libacl_cblas.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
|
|
||||||
find_library(acl_dvpp libacl_dvpp.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
|
|
||||||
find_library(acl_runtime libruntime.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
|
|
||||||
|
|
||||||
target_link_libraries(ms_serving ${acl} ${acl_retr} ${acl_cblas} ${acl_dvpp} ${acl_runtime})
|
|
||||||
target_link_libraries(ms_serving jpeg_turbo::jpeg securec)
|
|
||||||
else ()
|
|
||||||
target_link_libraries(ms_serving inference mindspore_gvar)
|
|
||||||
endif ()
|
|
|
@ -1,150 +0,0 @@
|
||||||
# MindSpore-based Inference Service Deployment
|
|
||||||
|
|
||||||
|
|
||||||
<!-- TOC -->
|
|
||||||
|
|
||||||
- [MindSpore-based Inference Service Deployment](#mindspore-based-inference-service-deployment)
|
|
||||||
- [Overview](#overview)
|
|
||||||
- [Starting Serving](#starting-serving)
|
|
||||||
- [Application Example](#application-example)
|
|
||||||
- [Exporting Model](#exporting-model)
|
|
||||||
- [Starting Serving Inference](#starting-serving-inference)
|
|
||||||
- [Client Samples](#client-samples)
|
|
||||||
- [Python Client Sample](#python-client-sample)
|
|
||||||
- [C++ Client Sample](#cpp-client-sample)
|
|
||||||
|
|
||||||
<!-- /TOC -->
|
|
||||||
<a href="https://gitee.com/mindspore/docs/blob/master/tutorials/source_en/advanced_use/serving.md" target="_blank"><img src="../_static/logo_source.png"></a>
|
|
||||||
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
MindSpore Serving is a lightweight and high-performance service module that helps MindSpore developers efficiently deploy online inference services in the production environment. After completing model training using MindSpore, you can export the MindSpore model and use MindSpore Serving to create an inference service for the model. Currently, only Ascend 910 is supported.
|
|
||||||
|
|
||||||
|
|
||||||
## Starting Serving
|
|
||||||
After MindSpore is installed using `pip`, the Serving executable program is stored in `/{your python path}/lib/python3.7/site-packages/mindspore/ms_serving`.
|
|
||||||
Run the following command to start Serving:
|
|
||||||
```bash
|
|
||||||
ms_serving [--help] [--model_path <MODEL_PATH>] [--model_name <MODEL_NAME>]
|
|
||||||
[--port <PORT>] [--device_id <DEVICE_ID>]
|
|
||||||
```
|
|
||||||
Parameters are described as follows:
|
|
||||||
|
|
||||||
|Parameter|Attribute|Function|Parameter Type|Default Value|Value Range|
|
|
||||||
|---|---|---|---|---|---|
|
|
||||||
|`--help`|Optional|Displays the help information about the startup command. |-|-|-|
|
|
||||||
|`--model_path=<MODEL_PATH>`|Mandatory|Path for storing the model to be loaded. |String|Null|-|
|
|
||||||
|`--model_name=<MODEL_NAME>`|Mandatory|Name of the model file to be loaded. |String|Null|-|
|
|
||||||
|`--=port <PORT>`|Optional|Specifies the external Serving port number. |Integer|5500|1–65535|
|
|
||||||
|`--device_id=<DEVICE_ID>`|Optional|Specifies device ID to be used.|Integer|0|0 to 7|
|
|
||||||
|
|
||||||
> Before running the startup command, add the path `/{your python path}/lib:/{your python path}/lib/python3.7/site-packages/mindspore/lib` to the environment variable `LD_LIBRARY_PATH`.
|
|
||||||
|
|
||||||
## Application Example
|
|
||||||
The following uses a simple network as an example to describe how to use MindSpore Serving.
|
|
||||||
|
|
||||||
### Exporting Model
|
|
||||||
Use [add_model.py](https://gitee.com/mindspore/mindspore/blob/master/serving/example/export_model/add_model.py) to build a network with only the Add operator and export the MindSpore inference deployment model.
|
|
||||||
|
|
||||||
```python
|
|
||||||
python add_model.py
|
|
||||||
```
|
|
||||||
Execute the script to generate the `tensor_add.mindir` file. The input of the model is two one-dimensional tensors with shape [2,2], and the output is the sum of the two input tensors.
|
|
||||||
|
|
||||||
### Starting Serving Inference
|
|
||||||
```bash
|
|
||||||
ms_serving --model_path={model directory} --model_name=tensor_add.mindir
|
|
||||||
```
|
|
||||||
If the server prints the `MS Serving Listening on 0.0.0.0:5500` log, the Serving has loaded the inference model.
|
|
||||||
|
|
||||||
### Client Samples
|
|
||||||
#### <span name="python-client-sample">Python Client Sample</span>
|
|
||||||
Obtain [ms_client.py](https://gitee.com/mindspore/mindspore/blob/master/serving/example/python_client/ms_client.py) and start the Python client.
|
|
||||||
```bash
|
|
||||||
python ms_client.py
|
|
||||||
```
|
|
||||||
|
|
||||||
If the following information is displayed, the Serving has correctly executed the inference of the Add network.
|
|
||||||
```
|
|
||||||
ms client received:
|
|
||||||
[[2. 2.]
|
|
||||||
[2. 2.]]
|
|
||||||
```
|
|
||||||
|
|
||||||
#### <span name="cpp-client-sample">C++ Client Sample</span>
|
|
||||||
1. Obtain an executable client sample program.
|
|
||||||
|
|
||||||
Download the [MindSpore source code](https://gitee.com/mindspore/mindspore). You can use either of the following methods to compile and obtain the client sample program:
|
|
||||||
+ When MindSpore is compiled using the source code, the Serving C++ client sample program is generated. You can find the `ms_client` executable program in the `build/mindspore/serving/example/cpp_client` directory.
|
|
||||||
+ Independent compilation
|
|
||||||
|
|
||||||
Preinstall [gRPC](https://gRPC.io).
|
|
||||||
|
|
||||||
Run the following command in the MindSpore source code path to compile a client sample program:
|
|
||||||
```bash
|
|
||||||
cd mindspore/serving/example/cpp_client
|
|
||||||
mkdir build && cd build
|
|
||||||
cmake -D GRPC_PATH={grpc_install_dir} ..
|
|
||||||
make
|
|
||||||
```
|
|
||||||
In the preceding command, `{grpc_install_dir}` indicates the gRPC installation path. Replace it with the actual gRPC installation path.
|
|
||||||
|
|
||||||
2. Start the client.
|
|
||||||
|
|
||||||
Execute `ms_client` to send an inference request to the Serving.
|
|
||||||
```bash
|
|
||||||
./ms_client --target=localhost:5500
|
|
||||||
```
|
|
||||||
If the following information is displayed, the Serving has correctly executed the inference of the Add network.
|
|
||||||
```
|
|
||||||
Compute [[1, 2], [3, 4]] + [[1, 2], [3, 4]]
|
|
||||||
Add result is 2 4 6 8
|
|
||||||
client received: RPC OK
|
|
||||||
```
|
|
||||||
|
|
||||||
The client code consists of the following parts:
|
|
||||||
|
|
||||||
1. Implement the client based on MSService::Stub and create a client instance.
|
|
||||||
```
|
|
||||||
class MSClient {
|
|
||||||
public:
|
|
||||||
explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {}
|
|
||||||
private:
|
|
||||||
std::unique_ptr<MSService::Stub> stub_;
|
|
||||||
};MSClient client(grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials()));
|
|
||||||
|
|
||||||
MSClient client(grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials()));
|
|
||||||
|
|
||||||
```
|
|
||||||
2. Build the request input parameter `Request`, output parameter `Reply`, and gRPC client `Context` based on the actual network input.
|
|
||||||
```
|
|
||||||
PredictRequest request;
|
|
||||||
PredictReply reply;
|
|
||||||
ClientContext context;
|
|
||||||
|
|
||||||
//construct tensor
|
|
||||||
Tensor data;
|
|
||||||
|
|
||||||
//set shape
|
|
||||||
TensorShape shape;
|
|
||||||
shape.add_dims(4);
|
|
||||||
*data.mutable_tensor_shape() = shape;
|
|
||||||
|
|
||||||
//set type
|
|
||||||
data.set_tensor_type(ms_serving::MS_FLOAT32);
|
|
||||||
std::vector<float> input_data{1, 2, 3, 4};
|
|
||||||
|
|
||||||
//set datas
|
|
||||||
data.set_data(input_data.data(), input_data.size());
|
|
||||||
|
|
||||||
//add tensor to request
|
|
||||||
*request.add_data() = data;
|
|
||||||
*request.add_data() = data;
|
|
||||||
```
|
|
||||||
3. Call the gRPC API to communicate with the Serving that has been started, and obtain the return value.
|
|
||||||
```
|
|
||||||
Status status = stub_->Predict(&context, request, &reply);
|
|
||||||
```
|
|
||||||
|
|
||||||
For details about the complete code, see [ms_client](https://gitee.com/mindspore/mindspore/blob/master/serving/example/cpp_client/ms_client.cc).
|
|
|
@ -1,151 +0,0 @@
|
||||||
# 基于MindSpore部署推理服务
|
|
||||||
|
|
||||||
|
|
||||||
<!-- TOC -->
|
|
||||||
|
|
||||||
- [基于MindSpore部署推理服务](#基于mindspore部署推理服务)
|
|
||||||
- [概述](#概述)
|
|
||||||
- [启动Serving服务](#启动serving服务)
|
|
||||||
- [应用示例](#应用示例)
|
|
||||||
- [导出模型](#导出模型)
|
|
||||||
- [启动Serving推理服务](#启动serving推理服务)
|
|
||||||
- [客户端示例](#客户端示例)
|
|
||||||
- [Python客户端示例](#python客户端示例)
|
|
||||||
- [C++客户端示例](#cpp客户端示例)
|
|
||||||
|
|
||||||
<!-- /TOC -->
|
|
||||||
<a href="https://gitee.com/mindspore/docs/blob/master/tutorials/source_zh_cn/advanced_use/serving.md" target="_blank"><img src="../_static/logo_source.png"></a>
|
|
||||||
|
|
||||||
|
|
||||||
## 概述
|
|
||||||
|
|
||||||
MindSpore Serving是一个轻量级、高性能的服务模块,旨在帮助MindSpore开发者在生产环境中高效部署在线推理服务。当用户使用MindSpore完成模型训练后,导出MindSpore模型,即可使用MindSpore Serving创建该模型的推理服务。当前Serving仅支持Ascend 910。
|
|
||||||
|
|
||||||
|
|
||||||
## 启动Serving服务
|
|
||||||
通过pip安装MindSpore后,Serving可执行程序位于`/{your python path}/lib/python3.7/site-packages/mindspore/ms_serving` 。
|
|
||||||
启动Serving服务命令如下
|
|
||||||
```bash
|
|
||||||
ms_serving [--help] [--model_path <MODEL_PATH>] [--model_name <MODEL_NAME>]
|
|
||||||
[--port <PORT>] [--device_id <DEVICE_ID>]
|
|
||||||
```
|
|
||||||
参数含义如下
|
|
||||||
|
|
||||||
|参数名|属性|功能描述|参数类型|默认值|取值范围|
|
|
||||||
|---|---|---|---|---|---|
|
|
||||||
|`--help`|可选|显示启动命令的帮助信息。|-|-|-|
|
|
||||||
|`--model_path=<MODEL_PATH>`|必选|指定待加载模型的存放路径。|String|空|-|
|
|
||||||
|`--model_name=<MODEL_NAME>`|必选|指定待加载模型的文件名。|String|空|-|
|
|
||||||
|`--port=<PORT>`|可选|指定Serving对外的端口号。|Integer|5500|1~65535|
|
|
||||||
|`--device_id=<DEVICE_ID>`|可选|指定使用的设备号|Integer|0|0~7|
|
|
||||||
|
|
||||||
> 执行启动命令前,需将`/{your python path}/lib:/{your python path}/lib/python3.7/site-packages/mindspore/lib`对应的路径加入到环境变量LD_LIBRARY_PATH中 。
|
|
||||||
|
|
||||||
## 应用示例
|
|
||||||
下面以一个简单的网络为例,演示MindSpore Serving如何使用。
|
|
||||||
|
|
||||||
### 导出模型
|
|
||||||
使用[add_model.py](https://gitee.com/mindspore/mindspore/blob/master/serving/example/export_model/add_model.py),构造一个只有Add算子的网络,并导出MindSpore推理部署模型。
|
|
||||||
|
|
||||||
```python
|
|
||||||
python add_model.py
|
|
||||||
```
|
|
||||||
执行脚本,生成`tensor_add.mindir`文件,该模型的输入为两个shape为[2,2]的二维Tensor,输出结果是两个输入Tensor之和。
|
|
||||||
|
|
||||||
### 启动Serving推理服务
|
|
||||||
```bash
|
|
||||||
ms_serving --model_path={model directory} --model_name=tensor_add.mindir
|
|
||||||
```
|
|
||||||
当服务端打印日志`MS Serving Listening on 0.0.0.0:5500`时,表示Serving服务已加载推理模型完毕。
|
|
||||||
|
|
||||||
### 客户端示例
|
|
||||||
#### <span name="python客户端示例">Python客户端示例</span>
|
|
||||||
获取[ms_client.py](https://gitee.com/mindspore/mindspore/blob/master/serving/example/python_client/ms_client.py),启动Python客户端。
|
|
||||||
```bash
|
|
||||||
python ms_client.py
|
|
||||||
```
|
|
||||||
|
|
||||||
显示如下返回值说明Serving服务已正确执行Add网络的推理。
|
|
||||||
```
|
|
||||||
ms client received:
|
|
||||||
[[2. 2.]
|
|
||||||
[2. 2.]]
|
|
||||||
```
|
|
||||||
|
|
||||||
#### <span name="cpp客户端示例">C++客户端示例</span>
|
|
||||||
1. 获取客户端示例执行程序
|
|
||||||
|
|
||||||
首先需要下载[MindSpore源码](https://gitee.com/mindspore/mindspore)。有两种方式编译并获取客户端示例程序:
|
|
||||||
+ 从源码编译MindSpore时候,将会编译产生Serving C++客户端示例程序,可在`build/mindspore/serving/example/cpp_client`目录下找到`ms_client`可执行程序。
|
|
||||||
+ 独立编译:
|
|
||||||
|
|
||||||
需要先预装[gRPC](https://gRPC.io)。
|
|
||||||
|
|
||||||
然后,在MindSpore源码路径中执行如下命令,编译一个客户端示例程序。
|
|
||||||
```bash
|
|
||||||
cd mindspore/serving/example/cpp_client
|
|
||||||
mkdir build && cd build
|
|
||||||
cmake -D GRPC_PATH={grpc_install_dir} ..
|
|
||||||
make
|
|
||||||
```
|
|
||||||
其中`{grpc_install_dir}`为gRPC安装时的路径,请替换为实际gRPC安装路径。
|
|
||||||
|
|
||||||
2. 启动客户端
|
|
||||||
|
|
||||||
执行ms_client,向Serving服务发送推理请求:
|
|
||||||
```bash
|
|
||||||
./ms_client --target=localhost:5500
|
|
||||||
```
|
|
||||||
显示如下返回值说明Serving服务已正确执行Add网络的推理。
|
|
||||||
```
|
|
||||||
Compute [[1, 2], [3, 4]] + [[1, 2], [3, 4]]
|
|
||||||
Add result is 2 4 6 8
|
|
||||||
client received: RPC OK
|
|
||||||
```
|
|
||||||
|
|
||||||
客户端代码主要包含以下几个部分:
|
|
||||||
|
|
||||||
1. 基于MSService::Stub实现Client,并创建Client实例。
|
|
||||||
```
|
|
||||||
class MSClient {
|
|
||||||
public:
|
|
||||||
explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {}
|
|
||||||
private:
|
|
||||||
std::unique_ptr<MSService::Stub> stub_;
|
|
||||||
};MSClient client(grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials()));
|
|
||||||
|
|
||||||
MSClient client(grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials()));
|
|
||||||
|
|
||||||
```
|
|
||||||
2. 根据网络的实际输入构造请求的入参Request、出参Reply和gRPC的客户端Context。
|
|
||||||
```
|
|
||||||
PredictRequest request;
|
|
||||||
PredictReply reply;
|
|
||||||
ClientContext context;
|
|
||||||
|
|
||||||
//construct tensor
|
|
||||||
Tensor data;
|
|
||||||
|
|
||||||
//set shape
|
|
||||||
TensorShape shape;
|
|
||||||
shape.add_dims(4);
|
|
||||||
*data.mutable_tensor_shape() = shape;
|
|
||||||
|
|
||||||
//set type
|
|
||||||
data.set_tensor_type(ms_serving::MS_FLOAT32);
|
|
||||||
std::vector<float> input_data{1, 2, 3, 4};
|
|
||||||
|
|
||||||
//set datas
|
|
||||||
data.set_data(input_data.data(), input_data.size());
|
|
||||||
|
|
||||||
//add tensor to request
|
|
||||||
*request.add_data() = data;
|
|
||||||
*request.add_data() = data;
|
|
||||||
```
|
|
||||||
3. 调用gRPC接口和已经启动的Serving服务通信,并取回返回值。
|
|
||||||
```
|
|
||||||
Status status = stub_->Predict(&context, request, &reply);
|
|
||||||
```
|
|
||||||
|
|
||||||
完整代码参考[ms_client](https://gitee.com/mindspore/mindspore/blob/master/serving/example/cpp_client/ms_client.cc)。
|
|
||||||
|
|
|
@ -1,243 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
#include <algorithm>
|
|
||||||
#include <fstream>
|
|
||||||
#include "serving/acl/acl_session.h"
|
|
||||||
#include "include/infer_log.h"
|
|
||||||
|
|
||||||
namespace mindspore::inference {
|
|
||||||
|
|
||||||
std::shared_ptr<InferSession> InferSession::CreateSession(const std::string &device, uint32_t device_id) {
|
|
||||||
try {
|
|
||||||
auto session = std::make_shared<AclSession>();
|
|
||||||
auto ret = session->InitEnv(device, device_id);
|
|
||||||
if (ret != SUCCESS) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
return session;
|
|
||||||
} catch (std::exception &e) {
|
|
||||||
MSI_LOG_ERROR << "Inference CreatSession failed";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Status AclSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) {
|
|
||||||
Status ret = model_process_.LoadModelFromFile(file_name, model_id);
|
|
||||||
if (ret != SUCCESS) {
|
|
||||||
MSI_LOG_ERROR << "Load model from file failed, model file " << file_name;
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
std::string dvpp_config_file;
|
|
||||||
auto index = file_name.rfind(".");
|
|
||||||
if (index == std::string::npos) {
|
|
||||||
dvpp_config_file = file_name;
|
|
||||||
} else {
|
|
||||||
dvpp_config_file = file_name.substr(0, index);
|
|
||||||
}
|
|
||||||
dvpp_config_file += "_dvpp_config.json";
|
|
||||||
std::ifstream fp(dvpp_config_file);
|
|
||||||
if (!fp.is_open()) {
|
|
||||||
MSI_LOG_INFO << "Dvpp config file not exist, model will execute with tensors as inputs, dvpp config file "
|
|
||||||
<< dvpp_config_file;
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
fp.close();
|
|
||||||
if (dvpp_process_.InitWithJsonConfig(dvpp_config_file) != SUCCESS) {
|
|
||||||
MSI_LOG_ERROR << "Dvpp config file parse error, dvpp config file " << dvpp_config_file;
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
execute_with_dvpp_ = true;
|
|
||||||
MSI_LOG_INFO << "Dvpp config success";
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status AclSession::UnloadModel(uint32_t /*model_id*/) {
|
|
||||||
model_process_.UnLoad();
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status AclSession::ExecuteModel(uint32_t /*model_id*/, const RequestBase &request,
|
|
||||||
ReplyBase &reply) { // set d context
|
|
||||||
aclError rt_ret = aclrtSetCurrentContext(context_);
|
|
||||||
if (rt_ret != ACL_ERROR_NONE) {
|
|
||||||
MSI_LOG_ERROR << "set the ascend device context failed";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
return model_process_.Execute(request, reply);
|
|
||||||
}
|
|
||||||
|
|
||||||
Status AclSession::PreProcess(uint32_t /*model_id*/, const InferImagesBase *images_input,
|
|
||||||
ImagesDvppOutput &dvpp_output) {
|
|
||||||
if (images_input == nullptr) {
|
|
||||||
MSI_LOG_ERROR << "images input is nullptr";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
auto batch_size = images_input->batch_size();
|
|
||||||
if (batch_size <= 0) {
|
|
||||||
MSI_LOG_ERROR << "invalid batch size " << images_input->batch_size();
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
std::vector<const void *> pic_buffer_list;
|
|
||||||
std::vector<size_t> pic_size_list;
|
|
||||||
for (size_t i = 0; i < batch_size; i++) {
|
|
||||||
const void *pic_buffer = nullptr;
|
|
||||||
uint32_t pic_size = 0;
|
|
||||||
if (!images_input->get(i, pic_buffer, pic_size) || pic_buffer == nullptr || pic_size == 0) {
|
|
||||||
MSI_LOG_ERROR << "Get request " << 0 << "th buffer failed";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
pic_buffer_list.push_back(pic_buffer);
|
|
||||||
pic_size_list.push_back(pic_size);
|
|
||||||
}
|
|
||||||
auto ret = dvpp_process_.Process(pic_buffer_list, pic_size_list, dvpp_output.buffer_device, dvpp_output.buffer_size);
|
|
||||||
if (ret != SUCCESS) {
|
|
||||||
MSI_LOG_ERROR << "dvpp process failed";
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status AclSession::ExecuteModel(uint32_t model_id, const ImagesRequestBase &images_inputs, // images for preprocess
|
|
||||||
const RequestBase &request, ReplyBase &reply) {
|
|
||||||
if (!execute_with_dvpp_) {
|
|
||||||
MSI_LOG_ERROR << "Unexpected images as inputs, DVPP not config";
|
|
||||||
return INFER_STATUS(INVALID_INPUTS) << "Unexpected images as inputs, DVPP not config";
|
|
||||||
}
|
|
||||||
aclError rt_ret = aclrtSetCurrentContext(context_);
|
|
||||||
if (rt_ret != ACL_ERROR_NONE) {
|
|
||||||
MSI_LOG_ERROR << "set the ascend device context failed";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
if (images_inputs.size() != 1) {
|
|
||||||
MSI_LOG_ERROR << "Only support one input to do DVPP preprocess";
|
|
||||||
return INFER_STATUS(INVALID_INPUTS) << "Only support one input to do DVPP preprocess";
|
|
||||||
}
|
|
||||||
if (images_inputs[0] == nullptr) {
|
|
||||||
MSI_LOG_ERROR << "Get first images input failed";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
if (images_inputs[0]->batch_size() != model_process_.GetBatchSize()) {
|
|
||||||
MSI_LOG_ERROR << "Input batch size " << images_inputs[0]->batch_size() << " not match Model batch size "
|
|
||||||
<< model_process_.GetBatchSize();
|
|
||||||
return INFER_STATUS(INVALID_INPUTS) << "Input batch size " << images_inputs[0]->batch_size()
|
|
||||||
<< " not match Model batch size " << model_process_.GetBatchSize();
|
|
||||||
}
|
|
||||||
if (request.size() != 0) {
|
|
||||||
MSI_LOG_ERROR << "only support one input, images input size is 1, tensor inputs is not 0 " << request.size();
|
|
||||||
return INFER_STATUS(INVALID_INPUTS) << "only support one input, images input size is 1, tensor inputs is not 0 "
|
|
||||||
<< request.size();
|
|
||||||
}
|
|
||||||
ImagesDvppOutput dvpp_output;
|
|
||||||
Status ret = PreProcess(model_id, images_inputs[0], dvpp_output);
|
|
||||||
if (ret != SUCCESS) {
|
|
||||||
MSI_LOG_ERROR << "DVPP preprocess failed";
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
ret = model_process_.Execute(dvpp_output.buffer_device, dvpp_output.buffer_size, reply);
|
|
||||||
if (ret != SUCCESS) {
|
|
||||||
MSI_LOG_ERROR << "Execute model failed";
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status AclSession::InitEnv(const std::string &device_type, uint32_t device_id) {
|
|
||||||
device_type_ = device_type;
|
|
||||||
device_id_ = device_id;
|
|
||||||
auto ret = aclInit(nullptr);
|
|
||||||
if (ret != ACL_ERROR_NONE) {
|
|
||||||
MSI_LOG_ERROR << "Execute aclInit Failed";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
MSI_LOG_INFO << "acl init success";
|
|
||||||
|
|
||||||
ret = aclrtSetDevice(device_id_);
|
|
||||||
if (ret != ACL_ERROR_NONE) {
|
|
||||||
MSI_LOG_ERROR << "acl open device " << device_id_ << " failed";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
MSI_LOG_INFO << "open device " << device_id_ << " success";
|
|
||||||
|
|
||||||
ret = aclrtCreateContext(&context_, device_id_);
|
|
||||||
if (ret != ACL_ERROR_NONE) {
|
|
||||||
MSI_LOG_ERROR << "acl create context failed";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
MSI_LOG_INFO << "create context success";
|
|
||||||
|
|
||||||
ret = aclrtCreateStream(&stream_);
|
|
||||||
if (ret != ACL_ERROR_NONE) {
|
|
||||||
MSI_LOG_ERROR << "acl create stream failed";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
MSI_LOG_INFO << "create stream success";
|
|
||||||
|
|
||||||
aclrtRunMode run_mode;
|
|
||||||
ret = aclrtGetRunMode(&run_mode);
|
|
||||||
if (ret != ACL_ERROR_NONE) {
|
|
||||||
MSI_LOG_ERROR << "acl get run mode failed";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
bool is_device = (run_mode == ACL_DEVICE);
|
|
||||||
model_process_.SetIsDevice(is_device);
|
|
||||||
MSI_LOG_INFO << "get run mode success is device input/output " << is_device;
|
|
||||||
|
|
||||||
if (dvpp_process_.InitResource(stream_) != SUCCESS) {
|
|
||||||
MSI_LOG_ERROR << "dvpp init resource failed";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
MSI_LOG_INFO << "Init acl success, device id " << device_id_;
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status AclSession::FinalizeEnv() {
|
|
||||||
dvpp_process_.Finalize();
|
|
||||||
aclError ret;
|
|
||||||
if (stream_ != nullptr) {
|
|
||||||
ret = aclrtDestroyStream(stream_);
|
|
||||||
if (ret != ACL_ERROR_NONE) {
|
|
||||||
MSI_LOG_ERROR << "destroy stream failed";
|
|
||||||
}
|
|
||||||
stream_ = nullptr;
|
|
||||||
}
|
|
||||||
MSI_LOG_INFO << "end to destroy stream";
|
|
||||||
if (context_ != nullptr) {
|
|
||||||
ret = aclrtDestroyContext(context_);
|
|
||||||
if (ret != ACL_ERROR_NONE) {
|
|
||||||
MSI_LOG_ERROR << "destroy context failed";
|
|
||||||
}
|
|
||||||
context_ = nullptr;
|
|
||||||
}
|
|
||||||
MSI_LOG_INFO << "end to destroy context";
|
|
||||||
|
|
||||||
ret = aclrtResetDevice(device_id_);
|
|
||||||
if (ret != ACL_ERROR_NONE) {
|
|
||||||
MSI_LOG_ERROR << "reset devie " << device_id_ << " failed";
|
|
||||||
}
|
|
||||||
MSI_LOG_INFO << "end to reset device " << device_id_;
|
|
||||||
|
|
||||||
ret = aclFinalize();
|
|
||||||
if (ret != ACL_ERROR_NONE) {
|
|
||||||
MSI_LOG_ERROR << "finalize acl failed";
|
|
||||||
}
|
|
||||||
MSI_LOG_INFO << "end to finalize acl";
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
AclSession::AclSession() = default;
|
|
||||||
} // namespace mindspore::inference
|
|
|
@ -1,57 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#ifndef MINDSPORE_SERVING_ACL_SESSION_H
|
|
||||||
#define MINDSPORE_SERVING_ACL_SESSION_H
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
|
||||||
#include <unordered_map>
|
|
||||||
#include <utility>
|
|
||||||
#include <memory>
|
|
||||||
#include <map>
|
|
||||||
|
|
||||||
#include "include/inference.h"
|
|
||||||
#include "serving/acl/model_process.h"
|
|
||||||
#include "serving/acl/dvpp_process.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace inference {
|
|
||||||
class AclSession : public InferSession {
|
|
||||||
public:
|
|
||||||
AclSession();
|
|
||||||
|
|
||||||
Status InitEnv(const std::string &device_type, uint32_t device_id) override;
|
|
||||||
Status FinalizeEnv() override;
|
|
||||||
Status LoadModelFromFile(const std::string &file_name, uint32_t &model_id) override;
|
|
||||||
Status UnloadModel(uint32_t model_id) override;
|
|
||||||
Status ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) override;
|
|
||||||
Status ExecuteModel(uint32_t model_id, const ImagesRequestBase &images_inputs, // images for preprocess
|
|
||||||
const RequestBase &request, ReplyBase &reply) override;
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::string device_type_;
|
|
||||||
int32_t device_id_;
|
|
||||||
aclrtStream stream_ = nullptr;
|
|
||||||
aclrtContext context_ = nullptr;
|
|
||||||
ModelProcess model_process_;
|
|
||||||
bool execute_with_dvpp_ = false;
|
|
||||||
DvppProcess dvpp_process_;
|
|
||||||
|
|
||||||
Status PreProcess(uint32_t model_id, const InferImagesBase *images_input, ImagesDvppOutput &dvpp_output);
|
|
||||||
};
|
|
||||||
} // namespace inference
|
|
||||||
} // namespace mindspore
|
|
||||||
#endif // MINDSPORE_SERVING_ACL_SESSION_H
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,159 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef INC_DVPP_PROCESS_ACL
|
|
||||||
#define INC_DVPP_PROCESS_ACL
|
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
|
||||||
#include "acl/acl.h"
|
|
||||||
#include "acl/acl_mdl.h"
|
|
||||||
#include "acl/acl_rt.h"
|
|
||||||
#include "acl/ops/acl_dvpp.h"
|
|
||||||
#include "include/inference.h"
|
|
||||||
|
|
||||||
namespace mindspore::inference {
|
|
||||||
|
|
||||||
struct DvppDecodePara {
|
|
||||||
acldvppPixelFormat pixel_format = PIXEL_FORMAT_YUV_SEMIPLANAR_420;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct DvppResizePara {
|
|
||||||
uint32_t output_width = 0;
|
|
||||||
uint32_t output_height = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
enum DvppCropType {
|
|
||||||
// crop left,top,right,bottom is given in config
|
|
||||||
kDvppCropTypeOffset = 0,
|
|
||||||
// crop left,top,right,bottom is calculated by image width/height and output crop width/height
|
|
||||||
kDvppCropTypeCentre = 1,
|
|
||||||
};
|
|
||||||
|
|
||||||
struct DvppRoiArea {
|
|
||||||
uint32_t left = 0;
|
|
||||||
uint32_t top = 0;
|
|
||||||
uint32_t right = 0;
|
|
||||||
uint32_t bottom = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct DvppCropInfo {
|
|
||||||
DvppCropType crop_type = kDvppCropTypeOffset;
|
|
||||||
DvppRoiArea crop_area; // when kDvppCropTypeOffset
|
|
||||||
uint32_t crop_width = 0; // when kDvppCropTypeCentre
|
|
||||||
uint32_t crop_height = 0; // when kDvppCropTypeCentre
|
|
||||||
};
|
|
||||||
|
|
||||||
struct DvppCropPara {
|
|
||||||
DvppCropInfo crop_info;
|
|
||||||
uint32_t output_width = 0;
|
|
||||||
uint32_t output_height = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct DvppCropAndPastePara {
|
|
||||||
DvppCropInfo crop_info;
|
|
||||||
DvppRoiArea paste_area;
|
|
||||||
uint32_t output_width = 0;
|
|
||||||
uint32_t output_height = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
class DvppProcess {
|
|
||||||
public:
|
|
||||||
DvppProcess();
|
|
||||||
~DvppProcess();
|
|
||||||
|
|
||||||
Status InitResource(aclrtStream stream);
|
|
||||||
void Finalize();
|
|
||||||
Status InitJpegDecodePara(const DvppDecodePara &decode_para); // jpeg decode + (resize | crop)
|
|
||||||
Status InitResizePara(const DvppResizePara &resize_para); // jpeg decode + resize
|
|
||||||
Status InitCropPara(const DvppCropPara &crop_para); // jpeg decode + crop
|
|
||||||
Status InitCropAndPastePara(const DvppCropAndPastePara &crop_and_paste_para); // jpeg decode + crop&paste
|
|
||||||
|
|
||||||
Status InitWithJsonConfig(const std::string &json_config);
|
|
||||||
|
|
||||||
// output device buffer will be destroy by DvppProcess itself.
|
|
||||||
Status Process(const void *pic_buffer, size_t pic_buffer_size, void *&output_device_buffer, size_t &output_size);
|
|
||||||
Status Process(const std::vector<const void *> &pic_buffer_list, const std::vector<size_t> &pic_buffer_size_list,
|
|
||||||
void *&output_device_buffer, size_t &output_size);
|
|
||||||
|
|
||||||
private:
|
|
||||||
uint32_t pic_width_ = 0;
|
|
||||||
uint32_t pic_height_ = 0;
|
|
||||||
|
|
||||||
DvppDecodePara decode_para_;
|
|
||||||
DvppResizePara resize_para_;
|
|
||||||
DvppCropPara crop_para_;
|
|
||||||
DvppCropAndPastePara crop_and_paste_para_;
|
|
||||||
// only one of the resize or crop flag can be true
|
|
||||||
bool to_resize_flag_ = false;
|
|
||||||
bool to_crop_flag_ = false;
|
|
||||||
bool to_crop_and_paste_flag_ = false;
|
|
||||||
|
|
||||||
void *input_pic_dev_buffer_ = nullptr;
|
|
||||||
uint32_t input_pic_buffer_size_ = 0;
|
|
||||||
|
|
||||||
uint32_t decode_output_buffer_size_ = 0;
|
|
||||||
void *decode_output_buffer_dev_ = nullptr;
|
|
||||||
acldvppPicDesc *decode_output_desc_ = nullptr;
|
|
||||||
|
|
||||||
acldvppResizeConfig *resize_config_ = nullptr;
|
|
||||||
acldvppRoiConfig *crop_area_ = nullptr;
|
|
||||||
acldvppRoiConfig *paste_area_ = nullptr;
|
|
||||||
|
|
||||||
acldvppPicDesc *vpc_output_desc_ = nullptr;
|
|
||||||
void *vpc_output_buffer_dev_ = nullptr; // vpc_output_buffer_size_ length
|
|
||||||
uint32_t vpc_output_buffer_size_ = 0;
|
|
||||||
|
|
||||||
void *batch_vpc_output_buffer_dev_ = nullptr; // batch_size_ * vpc_output_buffer_size_ length
|
|
||||||
uint32_t batch_size_ = 0;
|
|
||||||
|
|
||||||
aclrtStream stream_ = nullptr;
|
|
||||||
acldvppChannelDesc *dvpp_channel_desc_ = nullptr;
|
|
||||||
|
|
||||||
uint32_t AlignmentHelper(uint32_t org_size, uint32_t alignment) const;
|
|
||||||
uint32_t GetImageBufferSize(uint32_t stride_width, uint32_t stride_height, acldvppPixelFormat pixel_format) const;
|
|
||||||
Status GetPicDescStride(uint32_t width, uint32_t height, uint32_t &stride_width, uint32_t &stride_height);
|
|
||||||
Status GetPicDescStrideDecode(uint32_t width, uint32_t height, uint32_t &stride_width, uint32_t &stride_height);
|
|
||||||
Status InputInputBuffer(const void *pic_buffer, size_t pic_buffer_size);
|
|
||||||
Status InitDecodeOutputDesc(uint32_t image_width,
|
|
||||||
uint32_t image_height); // decode_output_desc_, decode_output_buffer_dev_
|
|
||||||
Status CheckRoiAreaWidthHeight(uint32_t width, uint32_t height);
|
|
||||||
Status CheckAndAdjustRoiArea(DvppRoiArea &area);
|
|
||||||
Status UpdateCropArea(uint32_t image_width, uint32_t image_height);
|
|
||||||
Status CheckResizeImageInfo(uint32_t image_width, uint32_t image_height) const;
|
|
||||||
void DestroyDecodeDesc();
|
|
||||||
|
|
||||||
Status InitVpcOutputDesc(uint32_t output_width, uint32_t output_height,
|
|
||||||
acldvppPixelFormat pixel_format); // vpc_output_desc_, vpc_output_buffer_dev_batch_
|
|
||||||
Status InitRoiAreaConfig(acldvppRoiConfig *&roi_area, const DvppRoiArea &init_para);
|
|
||||||
Status InitCommonCropPara(DvppCropInfo &crop_info, uint32_t out_width, uint32_t out_height);
|
|
||||||
Status InitResizeOutputDesc(); // vpc_output_desc_, vpc_output_buffer_dev_, resize_config
|
|
||||||
Status InitCropOutputDesc(); // vpc_output_desc_, vpc_output_buffer_dev_, crop_area_
|
|
||||||
Status InitCropAndPasteOutputDesc(); // vpc_output_desc_, vpc_output_buffer_dev_, crop_area_, paste_area_
|
|
||||||
void DestroyVpcOutputDesc();
|
|
||||||
|
|
||||||
Status ProcessDecode();
|
|
||||||
Status ProcessResize();
|
|
||||||
Status ProcessCrop();
|
|
||||||
Status ProcessCropAndPaste();
|
|
||||||
void DestroyResource();
|
|
||||||
|
|
||||||
Status GetJpegWidthHeight(const void *pic_buffer, size_t pic_buffer_size, uint32_t &image_width,
|
|
||||||
uint32_t &image_height);
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace mindspore::inference
|
|
||||||
|
|
||||||
#endif // INC_DVPP_PROCESS_ACL
|
|
|
@ -1,68 +0,0 @@
|
||||||
{
|
|
||||||
"preprocess": [
|
|
||||||
{
|
|
||||||
"input": {
|
|
||||||
"index": 0
|
|
||||||
},
|
|
||||||
"decode_para": {
|
|
||||||
"out_pixel_format": "YUV420SP"
|
|
||||||
},
|
|
||||||
"dvpp_process": {
|
|
||||||
"op_name": "resize",
|
|
||||||
"out_width": 224,
|
|
||||||
"out_height": 224
|
|
||||||
},
|
|
||||||
"sample of dvpp_process content": [
|
|
||||||
{
|
|
||||||
"op_name": "resize",
|
|
||||||
"out_width": 224,
|
|
||||||
"out_height": 224
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"op_name": "crop",
|
|
||||||
"crop_type": "offset",
|
|
||||||
"crop_left": 10,
|
|
||||||
"crop_top": 10,
|
|
||||||
"crop_right": 100,
|
|
||||||
"crop_bottom": 200,
|
|
||||||
"out_width": 224,
|
|
||||||
"out_height": 224
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"op_name": "crop",
|
|
||||||
"crop_type": "centre",
|
|
||||||
"crop_width": 100,
|
|
||||||
"crop_height": 100,
|
|
||||||
"out_width": 224,
|
|
||||||
"out_height": 224
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"op_name": "crop_and_paste",
|
|
||||||
"crop_type": "offset",
|
|
||||||
"crop_left": 10,
|
|
||||||
"crop_top": 10,
|
|
||||||
"crop_right": 100,
|
|
||||||
"crop_bottom": 200,
|
|
||||||
"paste_left": 10,
|
|
||||||
"paste_top": 10,
|
|
||||||
"paste_right": 100,
|
|
||||||
"paste_bottom": 200,
|
|
||||||
"out_width": 224,
|
|
||||||
"out_height": 224
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"op_name": "crop_and_paste",
|
|
||||||
"crop_type": "centre",
|
|
||||||
"crop_width": 100,
|
|
||||||
"crop_height": 100,
|
|
||||||
"paste_left": 10,
|
|
||||||
"paste_top": 10,
|
|
||||||
"paste_right": 100,
|
|
||||||
"paste_bottom": 200,
|
|
||||||
"out_width": 224,
|
|
||||||
"out_height": 224
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
|
@ -1,431 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "serving/acl/model_process.h"
|
|
||||||
#include <algorithm>
|
|
||||||
#include <unordered_map>
|
|
||||||
|
|
||||||
#include "include/infer_log.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace inference {
|
|
||||||
Status ModelProcess::PreInitModelResource() {
|
|
||||||
model_desc_ = aclmdlCreateDesc();
|
|
||||||
aclError acl_ret = aclmdlGetDesc(model_desc_, model_id_);
|
|
||||||
if (acl_ret != ACL_ERROR_NONE) {
|
|
||||||
MSI_LOG_ERROR << "Read model desc failed";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
Status ret = InitInputsBuffer();
|
|
||||||
if (ret != SUCCESS) {
|
|
||||||
MSI_LOG_ERROR << "Create input buffer failed";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
ret = InitOutputsBuffer();
|
|
||||||
if (ret != SUCCESS) {
|
|
||||||
MSI_LOG_ERROR << "Create output buffer failed";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status ModelProcess::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) {
|
|
||||||
aclError acl_ret = aclmdlLoadFromFile(file_name.c_str(), &model_id);
|
|
||||||
if (acl_ret != ACL_ERROR_NONE) {
|
|
||||||
MSI_LOG_ERROR << "Read model file failed, file name is " << file_name;
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
MSI_LOG_INFO << "Load model success " << file_name;
|
|
||||||
model_id_ = model_id;
|
|
||||||
if (PreInitModelResource() != SUCCESS) {
|
|
||||||
aclmdlUnload(model_id_);
|
|
||||||
MSI_LOG_ERROR << "Pre init model resource failed, file name is " << file_name;
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status ModelProcess::InitInputsBuffer() {
|
|
||||||
aclError ret;
|
|
||||||
size_t input_size = aclmdlGetNumInputs(model_desc_);
|
|
||||||
|
|
||||||
for (size_t i = 0; i < input_size; ++i) {
|
|
||||||
auto buffer_size = aclmdlGetInputSizeByIndex(model_desc_, i);
|
|
||||||
void *data_mem_buffer = nullptr;
|
|
||||||
if (!is_run_on_device_) { // need to copy input/output to/from device
|
|
||||||
ret = aclrtMalloc(&data_mem_buffer, buffer_size, ACL_MEM_MALLOC_NORMAL_ONLY);
|
|
||||||
if (ret != ACL_ERROR_NONE) {
|
|
||||||
MSI_LOG_ERROR << "Malloc device input buffer faild , input size " << buffer_size;
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
aclmdlIODims dims;
|
|
||||||
ret = aclmdlGetInputDims(model_desc_, i, &dims);
|
|
||||||
if (ret != ACL_ERROR_NONE) {
|
|
||||||
MSI_LOG_ERROR << "Get input shape failed";
|
|
||||||
if (!is_run_on_device_) {
|
|
||||||
aclrtFree(data_mem_buffer);
|
|
||||||
}
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
aclDataType data_type = aclmdlGetInputDataType(model_desc_, i);
|
|
||||||
std::vector<int64_t> shape(dims.dims, dims.dims + dims.dimCount);
|
|
||||||
input_infos_.emplace_back(AclTensorInfo{data_mem_buffer, buffer_size, data_type, shape});
|
|
||||||
}
|
|
||||||
MSI_LOG_INFO << "Create model inputs success";
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status ModelProcess::CreateDataBuffer(void *&data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset) {
|
|
||||||
aclError ret;
|
|
||||||
auto free_data_buffer = [this](void *dataMemBuffer) {
|
|
||||||
if (!is_run_on_device_) {
|
|
||||||
aclrtFree(dataMemBuffer);
|
|
||||||
} else {
|
|
||||||
aclrtFreeHost(dataMemBuffer);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
if (!is_run_on_device_) {
|
|
||||||
ret = aclrtMalloc(&data_mem_buffer, buffer_size, ACL_MEM_MALLOC_NORMAL_ONLY);
|
|
||||||
if (ret != ACL_ERROR_NONE) {
|
|
||||||
MSI_LOG_ERROR << "Malloc device buffer faild , buffer size " << buffer_size;
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
ret = aclrtMallocHost(&data_mem_buffer, buffer_size);
|
|
||||||
if (ret != ACL_ERROR_NONE) {
|
|
||||||
MSI_LOG_ERROR << "Malloc device buffer faild , buffer size " << buffer_size;
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
auto data_buffer = aclCreateDataBuffer(data_mem_buffer, buffer_size);
|
|
||||||
if (data_buffer == nullptr) {
|
|
||||||
MSI_LOG_ERROR << "Create Data Buffer failed";
|
|
||||||
free_data_buffer(data_mem_buffer);
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
ret = aclmdlAddDatasetBuffer(dataset, data_buffer);
|
|
||||||
if (ret != ACL_ERROR_NONE) {
|
|
||||||
MSI_LOG_ERROR << "add data buffer failed";
|
|
||||||
free_data_buffer(data_mem_buffer);
|
|
||||||
aclDestroyDataBuffer(data_buffer);
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status ModelProcess::InitOutputsBuffer() {
|
|
||||||
aclError ret;
|
|
||||||
outputs_ = aclmdlCreateDataset();
|
|
||||||
if (outputs_ == nullptr) {
|
|
||||||
MSI_LOG_ERROR << "Create input dataset failed";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
size_t output_size = aclmdlGetNumOutputs(model_desc_);
|
|
||||||
for (size_t i = 0; i < output_size; ++i) {
|
|
||||||
auto buffer_size = aclmdlGetOutputSizeByIndex(model_desc_, i);
|
|
||||||
|
|
||||||
void *data_mem_buffer = nullptr;
|
|
||||||
if (CreateDataBuffer(data_mem_buffer, buffer_size, outputs_) != SUCCESS) {
|
|
||||||
MSI_LOG_ERROR << "add output data buffer failed, buffer size " << buffer_size;
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
aclmdlIODims dims;
|
|
||||||
ret = aclmdlGetOutputDims(model_desc_, i, &dims);
|
|
||||||
if (ret != ACL_ERROR_NONE) {
|
|
||||||
MSI_LOG_ERROR << "Get input shape failed";
|
|
||||||
if (!is_run_on_device_) {
|
|
||||||
aclrtFree(data_mem_buffer);
|
|
||||||
} else {
|
|
||||||
aclrtFreeHost(data_mem_buffer);
|
|
||||||
}
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
aclDataType data_type = aclmdlGetOutputDataType(model_desc_, i);
|
|
||||||
std::vector<int64_t> shape(dims.dims, dims.dims + dims.dimCount);
|
|
||||||
output_infos_.emplace_back(AclTensorInfo{data_mem_buffer, buffer_size, data_type, shape});
|
|
||||||
}
|
|
||||||
MSI_LOG_INFO << "Create model output success";
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
void ModelProcess::DestroyInputsDataset() {
|
|
||||||
if (inputs_ == nullptr) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
for (size_t i = 0; i < aclmdlGetDatasetNumBuffers(inputs_); i++) {
|
|
||||||
auto dataBuffer = aclmdlGetDatasetBuffer(inputs_, i);
|
|
||||||
aclDestroyDataBuffer(dataBuffer);
|
|
||||||
}
|
|
||||||
aclmdlDestroyDataset(inputs_);
|
|
||||||
inputs_ = nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
void ModelProcess::DestroyInputsDataMem() {
|
|
||||||
if (!is_run_on_device_) {
|
|
||||||
for (const auto &item : input_infos_) {
|
|
||||||
aclrtFree(item.device_data);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
input_infos_.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
void ModelProcess::DestroyInputsBuffer() {
|
|
||||||
DestroyInputsDataMem();
|
|
||||||
DestroyInputsDataset();
|
|
||||||
}
|
|
||||||
|
|
||||||
void ModelProcess::DestroyOutputsBuffer() {
|
|
||||||
for (const auto &item : output_infos_) {
|
|
||||||
if (!is_run_on_device_) {
|
|
||||||
aclrtFree(item.device_data);
|
|
||||||
} else {
|
|
||||||
aclrtFreeHost(item.device_data);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
output_infos_.clear();
|
|
||||||
|
|
||||||
if (outputs_ == nullptr) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
for (size_t i = 0; i < aclmdlGetDatasetNumBuffers(outputs_); i++) {
|
|
||||||
auto dataBuffer = aclmdlGetDatasetBuffer(outputs_, i);
|
|
||||||
aclDestroyDataBuffer(dataBuffer);
|
|
||||||
}
|
|
||||||
aclmdlDestroyDataset(outputs_);
|
|
||||||
outputs_ = nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
void ModelProcess::UnLoad() {
|
|
||||||
auto ret = aclmdlUnload(model_id_);
|
|
||||||
if (ret != ACL_ERROR_NONE) {
|
|
||||||
MSI_LOG_ERROR << "Unload model failed";
|
|
||||||
}
|
|
||||||
if (model_desc_ != nullptr) {
|
|
||||||
aclmdlDestroyDesc(model_desc_);
|
|
||||||
model_desc_ = nullptr;
|
|
||||||
}
|
|
||||||
DestroyInputsBuffer();
|
|
||||||
DestroyOutputsBuffer();
|
|
||||||
MSI_LOG_INFO << "End unload model " << model_id_;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status ModelProcess::CheckAndInitInput(const RequestBase &request) {
|
|
||||||
aclError ret;
|
|
||||||
inputs_ = aclmdlCreateDataset();
|
|
||||||
// check inputs
|
|
||||||
if (request.size() != input_infos_.size()) {
|
|
||||||
MSI_LOG_ERROR << "inputs count not match, required count " << input_infos_.size() << ", given count "
|
|
||||||
<< request.size();
|
|
||||||
return INFER_STATUS(INVALID_INPUTS) << "inputs count not match, required count " << input_infos_.size()
|
|
||||||
<< ", given count " << request.size();
|
|
||||||
}
|
|
||||||
for (size_t i = 0; i < input_infos_.size(); i++) {
|
|
||||||
if (request[i] == nullptr) {
|
|
||||||
MSI_LOG_ERROR << "input " << i << " cannot be null";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
if (request[i]->data_size() != input_infos_[i].buffer_size) {
|
|
||||||
MSI_LOG_ERROR << "input " << i << " data size not match, required size " << input_infos_[i].buffer_size
|
|
||||||
<< ", given count " << request[i]->data_size();
|
|
||||||
return INFER_STATUS(INVALID_INPUTS) << "input " << i << " data size not match, required size "
|
|
||||||
<< input_infos_[i].buffer_size << ", given count " << request[i]->data_size();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// copy inputs
|
|
||||||
for (size_t i = 0; i < input_infos_.size(); i++) {
|
|
||||||
void *input_buffer = nullptr;
|
|
||||||
auto &info = input_infos_[i];
|
|
||||||
const void *data = request[i]->data();
|
|
||||||
if (!is_run_on_device_) {
|
|
||||||
ret = aclrtMemcpy(info.device_data, info.buffer_size, data, request[i]->data_size(), ACL_MEMCPY_HOST_TO_DEVICE);
|
|
||||||
if (ret != ACL_ERROR_NONE) {
|
|
||||||
MSI_LOG_ERROR << "memcpy input " << i << " data to device failed, buffer size " << request[i]->data_size();
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
input_buffer = info.device_data;
|
|
||||||
} else {
|
|
||||||
input_buffer = const_cast<void *>(data);
|
|
||||||
}
|
|
||||||
auto data_buffer = aclCreateDataBuffer(input_buffer, info.buffer_size);
|
|
||||||
if (data_buffer == nullptr) {
|
|
||||||
MSI_LOG_ERROR << "Create Data Buffer failed";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
ret = aclmdlAddDatasetBuffer(inputs_, data_buffer);
|
|
||||||
if (ret != ACL_ERROR_NONE) {
|
|
||||||
MSI_LOG_ERROR << "add data buffer failed";
|
|
||||||
aclDestroyDataBuffer(data_buffer);
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status ModelProcess::CheckAndInitDvppInput(const void *dvpp_outputs_buffer_dev, size_t dvpp_outputs_buffer_size,
|
|
||||||
size_t input_index) {
|
|
||||||
aclError ret;
|
|
||||||
inputs_ = aclmdlCreateDataset();
|
|
||||||
// check inputs
|
|
||||||
if (input_index >= input_infos_.size()) {
|
|
||||||
MSI_LOG_ERROR << "inputs count not match, required count " << input_infos_.size() << ", given index "
|
|
||||||
<< input_index;
|
|
||||||
return INFER_STATUS(INVALID_INPUTS) << "inputs count not match, required count " << input_infos_.size()
|
|
||||||
<< ", given index " << input_index;
|
|
||||||
}
|
|
||||||
if (dvpp_outputs_buffer_dev == nullptr) {
|
|
||||||
MSI_LOG_ERROR << "input " << 0 << " cannot be null";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
if (dvpp_outputs_buffer_size != input_infos_[input_index].buffer_size) {
|
|
||||||
MSI_LOG_ERROR << "input " << 0 << " data size not match, required size " << input_infos_[input_index].buffer_size
|
|
||||||
<< ", given count " << dvpp_outputs_buffer_size;
|
|
||||||
return INFER_STATUS(INVALID_INPUTS) << "input " << 0 << " data size not match, required size "
|
|
||||||
<< input_infos_[input_index].buffer_size << ", given count "
|
|
||||||
<< dvpp_outputs_buffer_size;
|
|
||||||
}
|
|
||||||
// copy inputs
|
|
||||||
auto &info = input_infos_[input_index];
|
|
||||||
auto data_buffer = aclCreateDataBuffer(const_cast<void *>(dvpp_outputs_buffer_dev), info.buffer_size);
|
|
||||||
if (data_buffer == nullptr) {
|
|
||||||
MSI_LOG_ERROR << "Create Data Buffer failed";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
ret = aclmdlAddDatasetBuffer(inputs_, data_buffer);
|
|
||||||
if (ret != ACL_ERROR_NONE) {
|
|
||||||
MSI_LOG_ERROR << "add data buffer failed";
|
|
||||||
aclDestroyDataBuffer(data_buffer);
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status ModelProcess::BuildOutputs(ReplyBase &reply) {
|
|
||||||
aclError ret;
|
|
||||||
// copy outputs
|
|
||||||
reply.clear();
|
|
||||||
|
|
||||||
std::unordered_map<aclDataType, inference::DataType> data_type_map = {
|
|
||||||
{ACL_FLOAT16, inference::kMSI_Float16}, {ACL_FLOAT, inference::kMSI_Float32}, {ACL_DOUBLE, inference::kMSI_Float64},
|
|
||||||
{ACL_INT8, inference::kMSI_Int8}, {ACL_INT16, inference::kMSI_Int16}, {ACL_INT32, inference::kMSI_Int32},
|
|
||||||
{ACL_INT64, inference::kMSI_Int64}, {ACL_UINT8, inference::kMSI_Uint8}, {ACL_UINT16, inference::kMSI_Uint16},
|
|
||||||
{ACL_UINT32, inference::kMSI_Uint32}, {ACL_UINT64, inference::kMSI_Uint64}, {ACL_BOOL, inference::kMSI_Bool},
|
|
||||||
};
|
|
||||||
auto trans_to_serving_type = [&data_type_map](aclDataType data_type) {
|
|
||||||
auto it = data_type_map.find(data_type);
|
|
||||||
if (it == data_type_map.end()) {
|
|
||||||
return inference::kMSI_Unknown;
|
|
||||||
} else {
|
|
||||||
return it->second;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
for (size_t i = 0; i < output_infos_.size(); i++) {
|
|
||||||
auto &info = output_infos_[i];
|
|
||||||
auto output = reply.add();
|
|
||||||
if (output == nullptr) {
|
|
||||||
MSI_LOG_ERROR << "add new output failed";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
output->set_data_type(trans_to_serving_type(info.data_type));
|
|
||||||
output->set_shape(info.dims);
|
|
||||||
if (!output->resize_data(info.buffer_size)) {
|
|
||||||
MSI_LOG_ERROR << "new output data buffer failed, data size " << info.buffer_size;
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
if (!is_run_on_device_) {
|
|
||||||
ret = aclrtMemcpy(output->mutable_data(), output->data_size(), info.device_data, info.buffer_size,
|
|
||||||
ACL_MEMCPY_DEVICE_TO_HOST);
|
|
||||||
if (ret != ACL_ERROR_NONE) {
|
|
||||||
MSI_LOG_ERROR << "Memcpy output " << i << " to host failed, memory size " << info.buffer_size;
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
ret = aclrtMemcpy(output->mutable_data(), output->data_size(), info.device_data, info.buffer_size,
|
|
||||||
ACL_MEMCPY_HOST_TO_HOST);
|
|
||||||
if (ret != ACL_ERROR_NONE) {
|
|
||||||
MSI_LOG_ERROR << "Memcpy output " << i << " to host failed, memory size " << info.buffer_size;
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status ModelProcess::Execute(const RequestBase &request, ReplyBase &reply) {
|
|
||||||
aclError acl_ret;
|
|
||||||
Status ret = CheckAndInitInput(request);
|
|
||||||
if (ret != SUCCESS) {
|
|
||||||
MSI_LOG_ERROR << "check or init input failed";
|
|
||||||
DestroyInputsDataset();
|
|
||||||
return ret; // forward status error
|
|
||||||
}
|
|
||||||
acl_ret = aclmdlExecute(model_id_, inputs_, outputs_);
|
|
||||||
DestroyInputsDataset();
|
|
||||||
if (acl_ret != ACL_ERROR_NONE) {
|
|
||||||
MSI_LOG_ERROR << "Execute Model Failed";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
ret = BuildOutputs(reply);
|
|
||||||
if (ret != SUCCESS) {
|
|
||||||
MSI_LOG_ERROR << "Build outputs faield";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
MSI_LOG_INFO << "excute model success";
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status ModelProcess::Execute(const void *dvpp_outputs_buffer_dev, size_t dvpp_outputs_buffer_size, ReplyBase &reply) {
|
|
||||||
aclError acl_ret;
|
|
||||||
if (input_infos_.size() != 1) {
|
|
||||||
MSI_LOG_ERROR << "can only support input size 1, now model inputs size is " << input_infos_.size();
|
|
||||||
return INFER_STATUS(INVALID_INPUTS) << "can only support input size 1, now model inputs size is "
|
|
||||||
<< input_infos_.size();
|
|
||||||
}
|
|
||||||
Status ret = CheckAndInitDvppInput(dvpp_outputs_buffer_dev, dvpp_outputs_buffer_size, 0);
|
|
||||||
if (ret != SUCCESS) {
|
|
||||||
MSI_LOG_ERROR << "check or init input failed";
|
|
||||||
DestroyInputsDataset();
|
|
||||||
return ret; // forward status msg
|
|
||||||
}
|
|
||||||
acl_ret = aclmdlExecute(model_id_, inputs_, outputs_);
|
|
||||||
DestroyInputsDataset();
|
|
||||||
if (acl_ret != ACL_ERROR_NONE) {
|
|
||||||
MSI_LOG_ERROR << "Execute Model Failed";
|
|
||||||
return INFER_STATUS(FAILED) << "Execute Model Failed";
|
|
||||||
}
|
|
||||||
ret = BuildOutputs(reply);
|
|
||||||
if (ret != SUCCESS) {
|
|
||||||
MSI_LOG_ERROR << "Build outputs faield";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
MSI_LOG_INFO << "excute model success";
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t ModelProcess::GetBatchSize() const {
|
|
||||||
if (input_infos_.empty()) {
|
|
||||||
MSI_LOG_ERROR << "Model is not loaded";
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
if (input_infos_[0].dims.empty()) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
return static_cast<size_t>(input_infos_[0].dims[0]);
|
|
||||||
}
|
|
||||||
} // namespace inference
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,83 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef INC_MODEL_PROCESS_ACL
|
|
||||||
#define INC_MODEL_PROCESS_ACL
|
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
|
||||||
#include "acl/acl.h"
|
|
||||||
#include "acl/acl_mdl.h"
|
|
||||||
#include "acl/acl_rt.h"
|
|
||||||
#include "include/inference.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace inference {
|
|
||||||
struct AclTensorInfo {
|
|
||||||
void *device_data;
|
|
||||||
size_t buffer_size;
|
|
||||||
aclDataType data_type;
|
|
||||||
std::vector<int64_t> dims;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct ImagesDvppOutput {
|
|
||||||
void *buffer_device = nullptr;
|
|
||||||
size_t buffer_size = 0;
|
|
||||||
size_t input_index = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
class ModelProcess {
|
|
||||||
public:
|
|
||||||
ModelProcess() {}
|
|
||||||
~ModelProcess() {}
|
|
||||||
|
|
||||||
Status LoadModelFromFile(const std::string &file_name, uint32_t &model_id);
|
|
||||||
void UnLoad();
|
|
||||||
|
|
||||||
// override this method to avoid request/reply data copy
|
|
||||||
Status Execute(const RequestBase &request, ReplyBase &reply);
|
|
||||||
Status Execute(const void *dvpp_outputs_buffer_dev, size_t dvpp_outputs_buffer_size, ReplyBase &reply);
|
|
||||||
void SetIsDevice(bool is_device) { is_run_on_device_ = is_device; }
|
|
||||||
|
|
||||||
size_t GetBatchSize() const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
uint32_t model_id_ = 0xffffffff;
|
|
||||||
// if run one device(AICPU), there is no need to alloc device memory and copy inputs to(/outputs from) device
|
|
||||||
bool is_run_on_device_ = false;
|
|
||||||
aclmdlDesc *model_desc_ = nullptr;
|
|
||||||
aclmdlDataset *inputs_ = nullptr;
|
|
||||||
aclmdlDataset *outputs_ = nullptr;
|
|
||||||
std::vector<AclTensorInfo> input_infos_;
|
|
||||||
std::vector<AclTensorInfo> output_infos_;
|
|
||||||
|
|
||||||
Status PreInitModelResource();
|
|
||||||
Status CreateDataBuffer(void *&data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset);
|
|
||||||
Status CheckAndInitInput(const RequestBase &request);
|
|
||||||
Status CheckAndInitDvppInput(const void *dvpp_outputs_buffer_dev, size_t dvpp_outputs_buffer_size,
|
|
||||||
size_t input_index);
|
|
||||||
Status BuildOutputs(ReplyBase &reply);
|
|
||||||
|
|
||||||
Status InitInputsBuffer();
|
|
||||||
Status InitOutputsBuffer();
|
|
||||||
void DestroyInputsDataset();
|
|
||||||
void DestroyInputsDataMem();
|
|
||||||
void DestroyInputsBuffer();
|
|
||||||
void DestroyOutputsBuffer();
|
|
||||||
};
|
|
||||||
} // namespace inference
|
|
||||||
} // namespace mindspore
|
|
||||||
|
|
||||||
#endif
|
|
|
@ -1,561 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include <map>
|
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
|
||||||
#include <functional>
|
|
||||||
#include <utility>
|
|
||||||
#include <nlohmann/json.hpp>
|
|
||||||
#include "serving/ms_service.pb.h"
|
|
||||||
#include "util/status.h"
|
|
||||||
#include "core/session.h"
|
|
||||||
#include "core/http_process.h"
|
|
||||||
#include "core/serving_tensor.h"
|
|
||||||
|
|
||||||
using ms_serving::MSService;
|
|
||||||
using ms_serving::PredictReply;
|
|
||||||
using ms_serving::PredictRequest;
|
|
||||||
using nlohmann::json;
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace serving {
|
|
||||||
const int BUF_MAX = 0x7FFFFFFF;
|
|
||||||
static constexpr char HTTP_DATA[] = "data";
|
|
||||||
static constexpr char HTTP_TENSOR[] = "tensor";
|
|
||||||
enum HTTP_TYPE { TYPE_DATA = 0, TYPE_TENSOR };
|
|
||||||
enum HTTP_DATA_TYPE { HTTP_DATA_NONE, HTTP_DATA_INT, HTTP_DATA_FLOAT };
|
|
||||||
|
|
||||||
static const std::map<inference::DataType, HTTP_DATA_TYPE> infer_type2_http_type{
|
|
||||||
{inference::DataType::kMSI_Int32, HTTP_DATA_INT}, {inference::DataType::kMSI_Float32, HTTP_DATA_FLOAT}};
|
|
||||||
|
|
||||||
Status GetPostMessage(struct evhttp_request *const req, std::string *const buf) {
|
|
||||||
Status status(SUCCESS);
|
|
||||||
size_t post_size = evbuffer_get_length(req->input_buffer);
|
|
||||||
if (post_size == 0) {
|
|
||||||
ERROR_INFER_STATUS(status, INVALID_INPUTS, "http message invalid");
|
|
||||||
return status;
|
|
||||||
} else if (post_size > BUF_MAX) {
|
|
||||||
ERROR_INFER_STATUS(status, INVALID_INPUTS, "http message is bigger than 0x7FFFFFFF.");
|
|
||||||
return status;
|
|
||||||
} else {
|
|
||||||
buf->resize(post_size);
|
|
||||||
auto src_data = evbuffer_pullup(req->input_buffer, -1);
|
|
||||||
if (src_data == nullptr) {
|
|
||||||
ERROR_INFER_STATUS(status, FAILED, "get http message failed.");
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
if (memcpy_s(buf->data(), post_size, src_data, post_size) != EOK) {
|
|
||||||
ERROR_INFER_STATUS(status, FAILED, "copy http message failed.");
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Status CheckRequestValid(const struct evhttp_request *const http_request) {
|
|
||||||
Status status(SUCCESS);
|
|
||||||
switch (evhttp_request_get_command(http_request)) {
|
|
||||||
case EVHTTP_REQ_POST:
|
|
||||||
return status;
|
|
||||||
default:
|
|
||||||
ERROR_INFER_STATUS(status, INVALID_INPUTS, "http message only support POST right now");
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void ErrorMessage(struct evhttp_request *const req, Status status) {
|
|
||||||
json error_json = {{"error_message", status.StatusMessage()}};
|
|
||||||
std::string out_error_str = error_json.dump();
|
|
||||||
struct evbuffer *retbuff = evbuffer_new();
|
|
||||||
evbuffer_add(retbuff, out_error_str.data(), out_error_str.size());
|
|
||||||
evhttp_send_reply(req, HTTP_OK, "Client", retbuff);
|
|
||||||
evbuffer_free(retbuff);
|
|
||||||
}
|
|
||||||
|
|
||||||
Status CheckMessageValid(const json &message_info, HTTP_TYPE *const type) {
|
|
||||||
Status status(SUCCESS);
|
|
||||||
int count = 0;
|
|
||||||
if (message_info.find(HTTP_DATA) != message_info.end()) {
|
|
||||||
*type = TYPE_DATA;
|
|
||||||
count++;
|
|
||||||
}
|
|
||||||
if (message_info.find(HTTP_TENSOR) != message_info.end()) {
|
|
||||||
*type = TYPE_TENSOR;
|
|
||||||
count++;
|
|
||||||
}
|
|
||||||
if (count != 1) {
|
|
||||||
ERROR_INFER_STATUS(status, INVALID_INPUTS, "http message must have only one type of (data, tensor)");
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<int64_t> GetJsonArrayShape(const json &json_array) {
|
|
||||||
std::vector<int64_t> json_shape;
|
|
||||||
const json *tmp_json = &json_array;
|
|
||||||
while (tmp_json->is_array()) {
|
|
||||||
if (tmp_json->empty()) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
json_shape.push_back(tmp_json->size());
|
|
||||||
tmp_json = &tmp_json->at(0);
|
|
||||||
}
|
|
||||||
return json_shape;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status GetScalarDataFromJson(const json &json_data_array, ServingTensor *const request_tensor, HTTP_DATA_TYPE type) {
|
|
||||||
Status status(SUCCESS);
|
|
||||||
auto type_name = [](const json &json_data) -> std::string {
|
|
||||||
if (json_data.is_number_integer()) {
|
|
||||||
return "integer";
|
|
||||||
} else if (json_data.is_number_float()) {
|
|
||||||
return "float";
|
|
||||||
}
|
|
||||||
return json_data.type_name();
|
|
||||||
};
|
|
||||||
const json *json_data = &json_data_array;
|
|
||||||
if (json_data_array.is_array()) {
|
|
||||||
if (json_data_array.size() != 1 || json_data_array[0].is_array()) {
|
|
||||||
status = INFER_STATUS(INVALID_INPUTS) << "get data failed, expected scalar data is scalar or shape(1) array, "
|
|
||||||
"now array shape is "
|
|
||||||
<< GetJsonArrayShape(json_data_array);
|
|
||||||
MSI_LOG_ERROR << status.StatusMessage();
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
json_data = &json_data_array.at(0);
|
|
||||||
}
|
|
||||||
if (type == HTTP_DATA_INT) {
|
|
||||||
auto data = reinterpret_cast<int32_t *>(request_tensor->mutable_data());
|
|
||||||
if (!json_data->is_number_integer()) {
|
|
||||||
status = INFER_STATUS(INVALID_INPUTS) << "get data failed, expected integer, given " << type_name(*json_data);
|
|
||||||
MSI_LOG_ERROR << status.StatusMessage();
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
data[0] = json_data->get<int32_t>();
|
|
||||||
} else if (type == HTTP_DATA_FLOAT) {
|
|
||||||
auto data = reinterpret_cast<float *>(request_tensor->mutable_data());
|
|
||||||
if (!json_data->is_number_float()) {
|
|
||||||
status = INFER_STATUS(INVALID_INPUTS) << "get data failed, expected float, given " << type_name(*json_data);
|
|
||||||
MSI_LOG_ERROR << status.StatusMessage();
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
data[0] = json_data->get<float>();
|
|
||||||
}
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status GetDataFromJson(const json &json_data_array, ServingTensor *const request_tensor, size_t data_index,
|
|
||||||
HTTP_DATA_TYPE type) {
|
|
||||||
Status status(SUCCESS);
|
|
||||||
auto type_name = [](const json &json_data) -> std::string {
|
|
||||||
if (json_data.is_number_integer()) {
|
|
||||||
return "integer";
|
|
||||||
} else if (json_data.is_number_float()) {
|
|
||||||
return "float";
|
|
||||||
}
|
|
||||||
return json_data.type_name();
|
|
||||||
};
|
|
||||||
size_t array_size = json_data_array.size();
|
|
||||||
if (type == HTTP_DATA_INT) {
|
|
||||||
auto data = reinterpret_cast<int32_t *>(request_tensor->mutable_data()) + data_index;
|
|
||||||
for (size_t k = 0; k < array_size; k++) {
|
|
||||||
auto &json_data = json_data_array[k];
|
|
||||||
if (!json_data.is_number_integer()) {
|
|
||||||
status = INFER_STATUS(INVALID_INPUTS) << "get data failed, expected integer, given " << type_name(json_data);
|
|
||||||
MSI_LOG_ERROR << status.StatusMessage();
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
data[k] = json_data.get<int32_t>();
|
|
||||||
}
|
|
||||||
} else if (type == HTTP_DATA_FLOAT) {
|
|
||||||
auto data = reinterpret_cast<float *>(request_tensor->mutable_data()) + data_index;
|
|
||||||
for (size_t k = 0; k < array_size; k++) {
|
|
||||||
auto &json_data = json_data_array[k];
|
|
||||||
if (!json_data.is_number_float()) {
|
|
||||||
status = INFER_STATUS(INVALID_INPUTS) << "get data failed, expected float, given " << type_name(json_data);
|
|
||||||
MSI_LOG_ERROR << status.StatusMessage();
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
data[k] = json_data.get<float>();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status RecusiveGetTensor(const json &json_data, size_t depth, ServingTensor *const request_tensor, size_t data_index,
|
|
||||||
HTTP_DATA_TYPE type) {
|
|
||||||
Status status(SUCCESS);
|
|
||||||
std::vector<int64_t> required_shape = request_tensor->shape();
|
|
||||||
if (depth >= required_shape.size()) {
|
|
||||||
status = INFER_STATUS(INVALID_INPUTS)
|
|
||||||
<< "input tensor shape dims is more than required dims " << required_shape.size();
|
|
||||||
MSI_LOG_ERROR << status.StatusMessage();
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
if (!json_data.is_array()) {
|
|
||||||
ERROR_INFER_STATUS(status, INVALID_INPUTS, "the tensor is constructed illegally");
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
if (json_data.size() != static_cast<size_t>(required_shape[depth])) {
|
|
||||||
status = INFER_STATUS(INVALID_INPUTS)
|
|
||||||
<< "tensor format request is constructed illegally, input tensor shape dim " << depth
|
|
||||||
<< " not match, required " << required_shape[depth] << ", given " << json_data.size();
|
|
||||||
MSI_LOG_ERROR << status.StatusMessage();
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
if (depth + 1 < required_shape.size()) {
|
|
||||||
size_t sub_element_cnt =
|
|
||||||
std::accumulate(required_shape.begin() + depth + 1, required_shape.end(), 1LL, std::multiplies<size_t>());
|
|
||||||
for (size_t k = 0; k < json_data.size(); k++) {
|
|
||||||
status = RecusiveGetTensor(json_data[k], depth + 1, request_tensor, data_index + sub_element_cnt * k, type);
|
|
||||||
if (status != SUCCESS) {
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
status = GetDataFromJson(json_data, request_tensor, data_index, type);
|
|
||||||
if (status != SUCCESS) {
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status TransDataToPredictRequest(const json &message_info, PredictRequest *const request) {
|
|
||||||
Status status = SUCCESS;
|
|
||||||
auto tensors = message_info.find(HTTP_DATA);
|
|
||||||
if (tensors == message_info.end()) {
|
|
||||||
ERROR_INFER_STATUS(status, INVALID_INPUTS, "http message do not have data type");
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
if (!tensors->is_array()) {
|
|
||||||
ERROR_INFER_STATUS(status, INVALID_INPUTS, "the input tensor list is not array");
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
auto const &json_shape = GetJsonArrayShape(*tensors);
|
|
||||||
if (json_shape.size() != 2) { // 2 is data format list deep
|
|
||||||
status = INFER_STATUS(INVALID_INPUTS)
|
|
||||||
<< "the data format request is constructed illegally, expected list nesting depth 2, given "
|
|
||||||
<< json_shape.size();
|
|
||||||
MSI_LOG_ERROR << status.StatusMessage();
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
if (tensors->size() != static_cast<size_t>(request->data_size())) {
|
|
||||||
status = INFER_STATUS(INVALID_INPUTS)
|
|
||||||
<< "model input count not match, model required " << request->data_size() << ", given " << tensors->size();
|
|
||||||
MSI_LOG_ERROR << status.StatusMessage();
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
for (size_t i = 0; i < tensors->size(); i++) {
|
|
||||||
const auto &tensor = tensors->at(i);
|
|
||||||
ServingTensor request_tensor(*(request->mutable_data(i)));
|
|
||||||
auto iter = infer_type2_http_type.find(request_tensor.data_type());
|
|
||||||
if (iter == infer_type2_http_type.end()) {
|
|
||||||
ERROR_INFER_STATUS(status, FAILED, "the model input type is not supported right now");
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
HTTP_DATA_TYPE type = iter->second;
|
|
||||||
if (!tensor.is_array()) {
|
|
||||||
ERROR_INFER_STATUS(status, INVALID_INPUTS, "the tensor is constructed illegally");
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
if (tensor.empty()) {
|
|
||||||
ERROR_INFER_STATUS(status, INVALID_INPUTS, "the input tensor is null");
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
if (tensor.size() != static_cast<size_t>(request_tensor.ElementNum())) {
|
|
||||||
status = INFER_STATUS(INVALID_INPUTS) << "input " << i << " element count not match, model required "
|
|
||||||
<< request_tensor.ElementNum() << ", given " << tensor.size();
|
|
||||||
MSI_LOG_ERROR << status.StatusMessage();
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
status = GetDataFromJson(tensor, &request_tensor, 0, type);
|
|
||||||
if (status != SUCCESS) {
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status TransTensorToPredictRequest(const json &message_info, PredictRequest *const request) {
|
|
||||||
Status status(SUCCESS);
|
|
||||||
auto tensors = message_info.find(HTTP_TENSOR);
|
|
||||||
if (tensors == message_info.end()) {
|
|
||||||
ERROR_INFER_STATUS(status, INVALID_INPUTS, "http message do not have tensor type");
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
if (!tensors->is_array()) {
|
|
||||||
ERROR_INFER_STATUS(status, INVALID_INPUTS, "the input tensor list is not array");
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
if (tensors->size() != static_cast<size_t>(request->data_size())) {
|
|
||||||
status =
|
|
||||||
INFER_STATUS(INVALID_INPUTS)
|
|
||||||
<< "model input count not match or json tensor request is constructed illegally, model input count required "
|
|
||||||
<< request->data_size() << ", given " << tensors->size();
|
|
||||||
MSI_LOG_ERROR << status.StatusMessage();
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (size_t i = 0; i < tensors->size(); i++) {
|
|
||||||
const auto &tensor = tensors->at(i);
|
|
||||||
ServingTensor request_tensor(*(request->mutable_data(i)));
|
|
||||||
|
|
||||||
auto iter = infer_type2_http_type.find(request_tensor.data_type());
|
|
||||||
if (iter == infer_type2_http_type.end()) {
|
|
||||||
ERROR_INFER_STATUS(status, FAILED, "the model input type is not supported right now");
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
HTTP_DATA_TYPE type = iter->second;
|
|
||||||
// check data shape
|
|
||||||
auto const &json_shape = GetJsonArrayShape(tensor);
|
|
||||||
auto is_scalar_shape = [](const std::vector<int64_t> &shape) {
|
|
||||||
return shape.empty() || (shape.size() == 1 && shape[0] == 1);
|
|
||||||
};
|
|
||||||
if (is_scalar_shape(request_tensor.shape())) {
|
|
||||||
return GetScalarDataFromJson(tensor, &request_tensor, type);
|
|
||||||
} else {
|
|
||||||
if (json_shape != request_tensor.shape()) { // data shape not match
|
|
||||||
status = INFER_STATUS(INVALID_INPUTS) << "input " << i << " shape is invalid, expected "
|
|
||||||
<< request_tensor.shape() << ", given " << json_shape;
|
|
||||||
MSI_LOG_ERROR << status.StatusMessage();
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
size_t depth = 0;
|
|
||||||
size_t data_index = 0;
|
|
||||||
status = RecusiveGetTensor(tensor, depth, &request_tensor, data_index, type);
|
|
||||||
if (status != SUCCESS) {
|
|
||||||
MSI_LOG_ERROR << "Transfer tensor to predict request failed";
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status TransHTTPMsgToPredictRequest(struct evhttp_request *const http_request, PredictRequest *const request,
|
|
||||||
HTTP_TYPE *const type) {
|
|
||||||
Status status = CheckRequestValid(http_request);
|
|
||||||
if (status != SUCCESS) {
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
std::string post_message;
|
|
||||||
status = GetPostMessage(http_request, &post_message);
|
|
||||||
if (status != SUCCESS) {
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
|
|
||||||
// get model required shape
|
|
||||||
std::vector<inference::InferTensor> tensor_list;
|
|
||||||
status = Session::Instance().GetModelInputsInfo(tensor_list);
|
|
||||||
if (status != SUCCESS) {
|
|
||||||
ERROR_INFER_STATUS(status, FAILED, "get model inputs info failed");
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
for (auto &item : tensor_list) {
|
|
||||||
auto input = request->add_data();
|
|
||||||
ServingTensor tensor(*input);
|
|
||||||
tensor.set_shape(item.shape());
|
|
||||||
tensor.set_data_type(item.data_type());
|
|
||||||
int64_t element_num = tensor.ElementNum();
|
|
||||||
int64_t data_type_size = tensor.GetTypeSize(tensor.data_type());
|
|
||||||
if (element_num <= 0 || INT64_MAX / element_num < data_type_size) {
|
|
||||||
ERROR_INFER_STATUS(status, FAILED, "model shape invalid");
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
tensor.resize_data(element_num * data_type_size);
|
|
||||||
}
|
|
||||||
MSI_TIME_STAMP_START(ParseJson)
|
|
||||||
json message_info;
|
|
||||||
try {
|
|
||||||
message_info = nlohmann::json::parse(post_message);
|
|
||||||
} catch (nlohmann::json::exception &e) {
|
|
||||||
std::string json_exception = e.what();
|
|
||||||
std::string error_message = "Illegal JSON format." + json_exception;
|
|
||||||
ERROR_INFER_STATUS(status, INVALID_INPUTS, error_message);
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
MSI_TIME_STAMP_END(ParseJson)
|
|
||||||
|
|
||||||
status = CheckMessageValid(message_info, type);
|
|
||||||
if (status != SUCCESS) {
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
switch (*type) {
|
|
||||||
case TYPE_DATA:
|
|
||||||
status = TransDataToPredictRequest(message_info, request);
|
|
||||||
break;
|
|
||||||
case TYPE_TENSOR:
|
|
||||||
status = TransTensorToPredictRequest(message_info, request);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
ERROR_INFER_STATUS(status, INVALID_INPUTS, "http message must have only one type of (data, tensor)");
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status GetJsonFromTensor(const ms_serving::Tensor &tensor, int len, int *const pos, json *const out_json) {
|
|
||||||
Status status(SUCCESS);
|
|
||||||
switch (tensor.tensor_type()) {
|
|
||||||
case ms_serving::MS_INT32: {
|
|
||||||
auto data = reinterpret_cast<const int *>(tensor.data().data()) + *pos;
|
|
||||||
std::vector<int32_t> result_tensor(len);
|
|
||||||
memcpy_s(result_tensor.data(), result_tensor.size() * sizeof(int32_t), data, len * sizeof(int32_t));
|
|
||||||
*out_json = std::move(result_tensor);
|
|
||||||
*pos += len;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case ms_serving::MS_FLOAT32: {
|
|
||||||
auto data = reinterpret_cast<const float *>(tensor.data().data()) + *pos;
|
|
||||||
std::vector<float> result_tensor(len);
|
|
||||||
(void)memcpy_s(result_tensor.data(), result_tensor.size() * sizeof(float), data, len * sizeof(float));
|
|
||||||
*out_json = std::move(result_tensor);
|
|
||||||
*pos += len;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
MSI_LOG(ERROR) << "the result type is not supported in restful api, type is " << tensor.tensor_type();
|
|
||||||
ERROR_INFER_STATUS(status, FAILED, "reply have unsupported type");
|
|
||||||
}
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status TransPredictReplyToData(const PredictReply &reply, json *const out_json) {
|
|
||||||
Status status(SUCCESS);
|
|
||||||
for (int i = 0; i < reply.result_size(); i++) {
|
|
||||||
(*out_json)["data"].push_back(json());
|
|
||||||
json &tensor_json = (*out_json)["data"].back();
|
|
||||||
int num = 1;
|
|
||||||
for (auto j = 0; j < reply.result(i).tensor_shape().dims_size(); j++) {
|
|
||||||
num *= reply.result(i).tensor_shape().dims(j);
|
|
||||||
}
|
|
||||||
int pos = 0;
|
|
||||||
status = GetJsonFromTensor(reply.result(i), num, &pos, &tensor_json);
|
|
||||||
if (status != SUCCESS) {
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status RecusiveGetJson(const ms_serving::Tensor &tensor, int depth, int *const pos, json *const out_json) {
|
|
||||||
Status status(SUCCESS);
|
|
||||||
if (depth >= 10) {
|
|
||||||
ERROR_INFER_STATUS(status, FAILED, "result tensor shape dims is larger than 10");
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
if (depth == tensor.tensor_shape().dims_size() - 1) {
|
|
||||||
status = GetJsonFromTensor(tensor, tensor.tensor_shape().dims(depth), pos, out_json);
|
|
||||||
if (status != SUCCESS) {
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (int i = 0; i < tensor.tensor_shape().dims(depth); i++) {
|
|
||||||
out_json->push_back(json());
|
|
||||||
json &tensor_json = out_json->back();
|
|
||||||
status = RecusiveGetJson(tensor, depth + 1, pos, &tensor_json);
|
|
||||||
if (status != SUCCESS) {
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status TransPredictReplyToTensor(const PredictReply &reply, json *const out_json) {
|
|
||||||
Status status(SUCCESS);
|
|
||||||
for (int i = 0; i < reply.result_size(); i++) {
|
|
||||||
(*out_json)["tensor"].push_back(json());
|
|
||||||
json &tensor_json = (*out_json)["tensor"].back();
|
|
||||||
int pos = 0;
|
|
||||||
status = RecusiveGetJson(reply.result(i), 0, &pos, &tensor_json);
|
|
||||||
if (status != SUCCESS) {
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status TransPredictReplyToHTTPMsg(const PredictReply &reply, const HTTP_TYPE &type, struct evbuffer *const buf) {
|
|
||||||
Status status(SUCCESS);
|
|
||||||
json out_json;
|
|
||||||
switch (type) {
|
|
||||||
case TYPE_DATA:
|
|
||||||
status = TransPredictReplyToData(reply, &out_json);
|
|
||||||
break;
|
|
||||||
case TYPE_TENSOR:
|
|
||||||
status = TransPredictReplyToTensor(reply, &out_json);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
ERROR_INFER_STATUS(status, FAILED, "http message must have only one type of (data, tensor)");
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
|
|
||||||
const std::string &out_str = out_json.dump();
|
|
||||||
evbuffer_add(buf, out_str.data(), out_str.size());
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status HttpHandleMsgDetail(struct evhttp_request *const req, void *const arg, struct evbuffer *const retbuff) {
|
|
||||||
PredictRequest request;
|
|
||||||
PredictReply reply;
|
|
||||||
HTTP_TYPE type;
|
|
||||||
MSI_TIME_STAMP_START(ParseRequest)
|
|
||||||
auto status = TransHTTPMsgToPredictRequest(req, &request, &type);
|
|
||||||
MSI_TIME_STAMP_END(ParseRequest)
|
|
||||||
if (status != SUCCESS) {
|
|
||||||
MSI_LOG(ERROR) << "restful trans to request failed";
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
MSI_TIME_STAMP_START(Predict)
|
|
||||||
status = Session::Instance().Predict(request, reply);
|
|
||||||
MSI_TIME_STAMP_END(Predict)
|
|
||||||
if (status != SUCCESS) {
|
|
||||||
MSI_LOG(ERROR) << "restful predict failed";
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
MSI_TIME_STAMP_START(CreateReplyJson)
|
|
||||||
status = TransPredictReplyToHTTPMsg(reply, type, retbuff);
|
|
||||||
MSI_TIME_STAMP_END(CreateReplyJson)
|
|
||||||
if (status != SUCCESS) {
|
|
||||||
MSI_LOG(ERROR) << "restful trans to reply failed";
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
void http_handler_msg(struct evhttp_request *const req, void *const arg) {
|
|
||||||
MSI_TIME_STAMP_START(TotalRestfulPredict)
|
|
||||||
struct evbuffer *retbuff = evbuffer_new();
|
|
||||||
if (retbuff == nullptr) {
|
|
||||||
MSI_LOG_ERROR << "Create event buffer failed";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
auto status = HttpHandleMsgDetail(req, arg, retbuff);
|
|
||||||
if (status != SUCCESS) {
|
|
||||||
ErrorMessage(req, status);
|
|
||||||
evbuffer_free(retbuff);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
MSI_TIME_STAMP_START(ReplyJson)
|
|
||||||
evhttp_send_reply(req, HTTP_OK, "Client", retbuff);
|
|
||||||
MSI_TIME_STAMP_END(ReplyJson)
|
|
||||||
evbuffer_free(retbuff);
|
|
||||||
MSI_TIME_STAMP_END(TotalRestfulPredict)
|
|
||||||
}
|
|
||||||
} // namespace serving
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,29 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#ifndef MINDSPORE_SERVING_HTTP_PROCESS_H
|
|
||||||
#define MINDSPORE_SERVING_HTTP_PROCESS_H
|
|
||||||
|
|
||||||
#include <evhttp.h>
|
|
||||||
#include <event.h>
|
|
||||||
#include <event2/http.h>
|
|
||||||
#include <event2/http_struct.h>
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace serving {
|
|
||||||
void http_handler_msg(struct evhttp_request *req, void *arg);
|
|
||||||
} // namespace serving
|
|
||||||
} // namespace mindspore
|
|
||||||
#endif // MINDSPORE_SERVER_H
|
|
|
@ -1,273 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include "core/server.h"
|
|
||||||
#include <evhttp.h>
|
|
||||||
#include <event.h>
|
|
||||||
#include <event2/thread.h>
|
|
||||||
#include <event2/listener.h>
|
|
||||||
#include <grpcpp/grpcpp.h>
|
|
||||||
#include <grpcpp/health_check_service_interface.h>
|
|
||||||
#include <grpcpp/ext/proto_server_reflection_plugin.h>
|
|
||||||
#include <future>
|
|
||||||
#include <memory>
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
#include <utility>
|
|
||||||
#include <atomic>
|
|
||||||
#include "include/infer_log.h"
|
|
||||||
#include "serving/ms_service.grpc.pb.h"
|
|
||||||
#include "core/util/option_parser.h"
|
|
||||||
#include "core/version_control/version_controller.h"
|
|
||||||
#include "core/session.h"
|
|
||||||
#include "core/serving_tensor.h"
|
|
||||||
#include "core/http_process.h"
|
|
||||||
|
|
||||||
using ms_serving::MSService;
|
|
||||||
using ms_serving::PredictReply;
|
|
||||||
using ms_serving::PredictRequest;
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace serving {
|
|
||||||
namespace {
|
|
||||||
static const uint32_t uint32max = 0x7FFFFFFF;
|
|
||||||
std::promise<void> exit_requested;
|
|
||||||
std::atomic_flag has_exited = ATOMIC_FLAG_INIT;
|
|
||||||
|
|
||||||
static const char kServerHttpIp[] = "0.0.0.0";
|
|
||||||
|
|
||||||
void ClearEnv() { Session::Instance().Clear(); }
|
|
||||||
void HandleSignal(int sig) {
|
|
||||||
if (!has_exited.test_and_set()) {
|
|
||||||
exit_requested.set_value();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
grpc::Status CreatGRPCStatus(const Status &status) {
|
|
||||||
switch (status.StatusCode()) {
|
|
||||||
case SUCCESS:
|
|
||||||
return grpc::Status::OK;
|
|
||||||
case FAILED:
|
|
||||||
return grpc::Status::CANCELLED;
|
|
||||||
case INVALID_INPUTS: {
|
|
||||||
auto status_msg = status.StatusMessage();
|
|
||||||
if (status_msg.empty()) {
|
|
||||||
status_msg = "The Predict Inputs do not match the Model Request!";
|
|
||||||
}
|
|
||||||
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, status_msg);
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return grpc::Status::CANCELLED;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
// Service Implement
|
|
||||||
class MSServiceImpl final : public MSService::Service {
|
|
||||||
grpc::Status Predict(grpc::ServerContext *context, const PredictRequest *request, PredictReply *reply) override {
|
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
|
||||||
MSI_TIME_STAMP_START(Predict)
|
|
||||||
auto res = Session::Instance().Predict(*request, *reply);
|
|
||||||
MSI_TIME_STAMP_END(Predict)
|
|
||||||
if (res != inference::SUCCESS) {
|
|
||||||
return CreatGRPCStatus(res);
|
|
||||||
}
|
|
||||||
MSI_LOG(INFO) << "Finish call service Eval";
|
|
||||||
return grpc::Status::OK;
|
|
||||||
}
|
|
||||||
|
|
||||||
grpc::Status Test(grpc::ServerContext *context, const PredictRequest *request, PredictReply *reply) override {
|
|
||||||
MSI_LOG(INFO) << "TestService call";
|
|
||||||
return grpc::Status::OK;
|
|
||||||
}
|
|
||||||
std::mutex mutex_;
|
|
||||||
};
|
|
||||||
|
|
||||||
static std::pair<struct evhttp *, struct event_base *> NewHttpServer() {
|
|
||||||
auto option_args = Options::Instance().GetArgs();
|
|
||||||
int32_t http_port = option_args->rest_api_port;
|
|
||||||
// init http server
|
|
||||||
event_init();
|
|
||||||
evthread_use_pthreads();
|
|
||||||
struct event_base *eb = event_base_new();
|
|
||||||
if (eb == nullptr) {
|
|
||||||
MSI_LOG(ERROR) << "Serving Error: RESTful server start failed, new http event failed";
|
|
||||||
std::cout << "Serving Error: RESTful server start failed, new http event failed" << std::endl;
|
|
||||||
return std::make_pair(nullptr, nullptr);
|
|
||||||
}
|
|
||||||
struct evhttp *http_server = evhttp_new(eb);
|
|
||||||
if (http_server == nullptr) {
|
|
||||||
MSI_LOG(ERROR) << "Serving Error: RESTful server start failed, create http server faild";
|
|
||||||
std::cout << "Serving Error: RESTful server start failed, create http server faild" << std::endl;
|
|
||||||
event_base_free(eb);
|
|
||||||
return std::make_pair(nullptr, nullptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
struct sockaddr_in sin = {};
|
|
||||||
sin.sin_family = AF_INET;
|
|
||||||
sin.sin_port = htons(http_port);
|
|
||||||
auto listener =
|
|
||||||
evconnlistener_new_bind(eb, nullptr, nullptr, LEV_OPT_REUSEABLE | LEV_OPT_CLOSE_ON_EXEC | LEV_OPT_CLOSE_ON_FREE, -1,
|
|
||||||
reinterpret_cast<struct sockaddr *>(&sin), sizeof(sin));
|
|
||||||
if (listener == nullptr) {
|
|
||||||
MSI_LOG_ERROR << "Serving Error: RESTful server start failed, create http listener faild, port " << http_port;
|
|
||||||
std::cout << "Serving Error: RESTful server start failed, create http listener faild, port " << http_port
|
|
||||||
<< std::endl;
|
|
||||||
event_base_free(eb);
|
|
||||||
evhttp_free(http_server);
|
|
||||||
return std::make_pair(nullptr, nullptr);
|
|
||||||
}
|
|
||||||
auto bound = evhttp_bind_listener(http_server, listener);
|
|
||||||
if (bound == nullptr) {
|
|
||||||
MSI_LOG_ERROR << "Serving Error: RESTful server start failed, bind http listener to server faild, port "
|
|
||||||
<< http_port;
|
|
||||||
std::cout << "Serving Error: RESTful server start failed, bind http listener to server faild, port " << http_port
|
|
||||||
<< std::endl;
|
|
||||||
evconnlistener_free(listener);
|
|
||||||
event_base_free(eb);
|
|
||||||
evhttp_free(http_server);
|
|
||||||
return std::make_pair(nullptr, nullptr);
|
|
||||||
}
|
|
||||||
return std::make_pair(http_server, eb);
|
|
||||||
}
|
|
||||||
|
|
||||||
Status BuildAndStartModelInner() {
|
|
||||||
Status res;
|
|
||||||
auto option_args = Options::Instance().GetArgs();
|
|
||||||
std::string model_path = option_args->model_path;
|
|
||||||
std::string model_name = option_args->model_name;
|
|
||||||
std::string device_type = option_args->device_type;
|
|
||||||
auto device_id = option_args->device_id;
|
|
||||||
res = Session::Instance().CreatDeviceSession(device_type, device_id);
|
|
||||||
if (res != SUCCESS) {
|
|
||||||
MSI_LOG(ERROR) << "Serving Error: create inference session failed, device type " << device_type << " device id "
|
|
||||||
<< device_id;
|
|
||||||
std::cout << "Serving Error: create inference session failed, device type " << device_type << " device id "
|
|
||||||
<< device_id << std::endl;
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
VersionController version_controller(option_args->poll_model_wait_seconds, model_path, model_name);
|
|
||||||
res = version_controller.Run();
|
|
||||||
if (res != SUCCESS) {
|
|
||||||
MSI_LOG(ERROR) << "Serving Error: load model failed, model directory " << option_args->model_path << " model name "
|
|
||||||
<< option_args->model_name;
|
|
||||||
std::cout << "Serving Error: load model failed, model directory " << option_args->model_path << " model name "
|
|
||||||
<< option_args->model_name << std::endl;
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status BuildAndStartModel() {
|
|
||||||
try {
|
|
||||||
auto status = BuildAndStartModelInner();
|
|
||||||
return status;
|
|
||||||
} catch (const std::bad_alloc &ex) {
|
|
||||||
MSI_LOG(ERROR) << "Serving Error: malloc memory failed";
|
|
||||||
std::cout << "Serving Error: malloc memory failed" << std::endl;
|
|
||||||
} catch (const std::runtime_error &ex) {
|
|
||||||
MSI_LOG(ERROR) << "Serving Error: runtime error occurred: " << ex.what();
|
|
||||||
std::cout << "Serving Error: runtime error occurred: " << ex.what() << std::endl;
|
|
||||||
} catch (const std::exception &ex) {
|
|
||||||
MSI_LOG(ERROR) << "Serving Error: exception occurred: " << ex.what();
|
|
||||||
std::cout << "Serving Error: exception occurred: " << ex.what() << std::endl;
|
|
||||||
} catch (...) {
|
|
||||||
MSI_LOG(ERROR) << "Serving Error: exception occurred";
|
|
||||||
std::cout << "Serving Error: exception occurred";
|
|
||||||
}
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status Server::BuildAndStart() {
|
|
||||||
// handle exit signal
|
|
||||||
signal(SIGINT, HandleSignal);
|
|
||||||
signal(SIGTERM, HandleSignal);
|
|
||||||
Status res = BuildAndStartModel();
|
|
||||||
if (res != SUCCESS) {
|
|
||||||
ClearEnv();
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
auto option_args = Options::Instance().GetArgs();
|
|
||||||
std::string server_address = std::string(kServerHttpIp) + ":" + std::to_string(option_args->grpc_port);
|
|
||||||
|
|
||||||
auto http_server_new_ret = NewHttpServer();
|
|
||||||
struct evhttp *http_server = http_server_new_ret.first;
|
|
||||||
struct event_base *eb = http_server_new_ret.second;
|
|
||||||
if (http_server == nullptr || eb == nullptr) {
|
|
||||||
MSI_LOG(ERROR) << "Serving Error: RESTful server start failed";
|
|
||||||
std::cout << "Serving Error: RESTful server start failed" << std::endl;
|
|
||||||
ClearEnv();
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
auto exit_http = [eb, http_server]() {
|
|
||||||
evhttp_free(http_server);
|
|
||||||
event_base_free(eb);
|
|
||||||
};
|
|
||||||
int32_t http_port = option_args->rest_api_port;
|
|
||||||
std::string http_addr = kServerHttpIp;
|
|
||||||
|
|
||||||
evhttp_set_timeout(http_server, 60);
|
|
||||||
evhttp_set_gencb(http_server, http_handler_msg, nullptr);
|
|
||||||
|
|
||||||
// grpc server
|
|
||||||
MSServiceImpl ms_service;
|
|
||||||
grpc::EnableDefaultHealthCheckService(true);
|
|
||||||
grpc::reflection::InitProtoReflectionServerBuilderPlugin();
|
|
||||||
// Set the port is not reuseable
|
|
||||||
auto option = grpc::MakeChannelArgumentOption(GRPC_ARG_ALLOW_REUSEPORT, 0);
|
|
||||||
grpc::ServerBuilder serverBuilder;
|
|
||||||
serverBuilder.SetOption(std::move(option));
|
|
||||||
serverBuilder.SetMaxMessageSize(uint32max);
|
|
||||||
serverBuilder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
|
||||||
serverBuilder.RegisterService(&ms_service);
|
|
||||||
std::unique_ptr<grpc::Server> server(serverBuilder.BuildAndStart());
|
|
||||||
if (server == nullptr) {
|
|
||||||
MSI_LOG(ERROR) << "Serving Error: create server failed, gRPC address " << server_address << ", RESTful address "
|
|
||||||
<< http_addr << ":" << http_port << ", model directory " << option_args->model_path << " model name "
|
|
||||||
<< option_args->model_name << ", device type " << option_args->device_type << ", device id "
|
|
||||||
<< option_args->device_id;
|
|
||||||
std::cout << "Serving Error: create server failed, gRPC address " << server_address << ", RESTful address "
|
|
||||||
<< http_addr << ":" << http_port << ", model directory " << option_args->model_path << " model name "
|
|
||||||
<< option_args->model_name << ", device type " << option_args->device_type << ", device id "
|
|
||||||
<< option_args->device_id << std::endl;
|
|
||||||
ClearEnv();
|
|
||||||
exit_http();
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
auto grpc_server_run = [&server, &server_address]() {
|
|
||||||
MSI_LOG(INFO) << "MS Serving grpc listening on " << server_address;
|
|
||||||
std::cout << "Serving: MS Serving gRPC start success, listening on " << server_address << std::endl;
|
|
||||||
server->Wait();
|
|
||||||
};
|
|
||||||
auto http_server_run = [&eb, &http_addr, &http_port]() {
|
|
||||||
MSI_LOG(INFO) << "MS Serving restful listening on " << http_addr << ":" << http_port;
|
|
||||||
std::cout << "Serving: MS Serving RESTful start success, listening on " << http_addr << ":" << http_port
|
|
||||||
<< std::endl;
|
|
||||||
event_base_dispatch(eb);
|
|
||||||
};
|
|
||||||
std::thread grpc_thread(grpc_server_run);
|
|
||||||
std::thread restful_thread(http_server_run);
|
|
||||||
auto exit_future = exit_requested.get_future();
|
|
||||||
exit_future.wait();
|
|
||||||
ClearEnv();
|
|
||||||
server->Shutdown();
|
|
||||||
event_base_loopexit(eb, nullptr);
|
|
||||||
grpc_thread.join();
|
|
||||||
restful_thread.join();
|
|
||||||
exit_http();
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
} // namespace serving
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,30 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#ifndef MINDSPORE_SERVER_H
|
|
||||||
#define MINDSPORE_SERVER_H
|
|
||||||
|
|
||||||
#include "util/status.h"
|
|
||||||
namespace mindspore {
|
|
||||||
namespace serving {
|
|
||||||
class Server {
|
|
||||||
public:
|
|
||||||
Server() = default;
|
|
||||||
~Server() = default;
|
|
||||||
Status BuildAndStart();
|
|
||||||
};
|
|
||||||
} // namespace serving
|
|
||||||
} // namespace mindspore
|
|
||||||
#endif // MINDSPORE_SERVER_H
|
|
|
@ -1,194 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "core/serving_tensor.h"
|
|
||||||
#include <vector>
|
|
||||||
#include <unordered_map>
|
|
||||||
#include <string>
|
|
||||||
#include <algorithm>
|
|
||||||
#include "include/infer_log.h"
|
|
||||||
|
|
||||||
using std::string;
|
|
||||||
using std::unordered_map;
|
|
||||||
using std::vector;
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace serving {
|
|
||||||
using inference::DataType;
|
|
||||||
using inference::InferTensorBase;
|
|
||||||
|
|
||||||
const size_t kMaxShapeElementCount = INT32_MAX;
|
|
||||||
const size_t kMaxDataBufferSize = UINT32_MAX;
|
|
||||||
|
|
||||||
ServingTensor::ServingTensor(ms_serving::Tensor &other) : tensor_(other) {}
|
|
||||||
|
|
||||||
ServingTensor::~ServingTensor() {}
|
|
||||||
|
|
||||||
DataType ServingTensor::data_type() const {
|
|
||||||
const std::unordered_map<ms_serving::DataType, inference::DataType> type2id_map{
|
|
||||||
{ms_serving::MS_UNKNOWN, inference::kMSI_Unknown}, {ms_serving::MS_BOOL, inference::kMSI_Bool},
|
|
||||||
{ms_serving::MS_INT8, inference::kMSI_Int8}, {ms_serving::MS_UINT8, inference::kMSI_Uint8},
|
|
||||||
{ms_serving::MS_INT16, inference::kMSI_Int16}, {ms_serving::MS_UINT16, inference::kMSI_Uint16},
|
|
||||||
{ms_serving::MS_INT32, inference::kMSI_Int32}, {ms_serving::MS_UINT32, inference::kMSI_Uint32},
|
|
||||||
{ms_serving::MS_INT64, inference::kMSI_Int64}, {ms_serving::MS_UINT64, inference::kMSI_Uint64},
|
|
||||||
{ms_serving::MS_FLOAT16, inference::kMSI_Float16}, {ms_serving::MS_FLOAT32, inference::kMSI_Float32},
|
|
||||||
{ms_serving::MS_FLOAT64, inference::kMSI_Float64},
|
|
||||||
};
|
|
||||||
auto it = type2id_map.find(tensor_.tensor_type());
|
|
||||||
if (it == type2id_map.end()) {
|
|
||||||
MSI_LOG_WARNING << "failed to get data type, undefined data type " << tensor_.tensor_type();
|
|
||||||
return inference::kMSI_Unknown;
|
|
||||||
} else {
|
|
||||||
return it->second;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void ServingTensor::set_data_type(DataType data_type) {
|
|
||||||
const std::unordered_map<inference::DataType, ms_serving::DataType> id2type_map{
|
|
||||||
{inference::kMSI_Unknown, ms_serving::MS_UNKNOWN}, {inference::kMSI_Bool, ms_serving::MS_BOOL},
|
|
||||||
{inference::kMSI_Float64, ms_serving::MS_FLOAT64}, {inference::kMSI_Int8, ms_serving::MS_INT8},
|
|
||||||
{inference::kMSI_Uint8, ms_serving::MS_UINT8}, {inference::kMSI_Int16, ms_serving::MS_INT16},
|
|
||||||
{inference::kMSI_Uint16, ms_serving::MS_UINT16}, {inference::kMSI_Int32, ms_serving::MS_INT32},
|
|
||||||
{inference::kMSI_Uint32, ms_serving::MS_UINT32}, {inference::kMSI_Int64, ms_serving::MS_INT64},
|
|
||||||
{inference::kMSI_Uint64, ms_serving::MS_UINT64}, {inference::kMSI_Float16, ms_serving::MS_FLOAT16},
|
|
||||||
{inference::kMSI_Float32, ms_serving::MS_FLOAT32},
|
|
||||||
};
|
|
||||||
auto it = id2type_map.find(data_type);
|
|
||||||
if (it == id2type_map.end()) {
|
|
||||||
MSI_LOG_WARNING << "failed to set data type, undefined data type " << data_type;
|
|
||||||
tensor_.set_tensor_type(ms_serving::MS_UNKNOWN);
|
|
||||||
} else {
|
|
||||||
tensor_.set_tensor_type(it->second);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<int64_t> ServingTensor::shape() const {
|
|
||||||
std::vector<int64_t> result;
|
|
||||||
auto dims = tensor_.tensor_shape().dims();
|
|
||||||
std::transform(dims.begin(), dims.end(), std::back_inserter(result), [](const int64_t dim) { return dim; });
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
void ServingTensor::set_shape(const std::vector<int64_t> &shape) {
|
|
||||||
auto tensor_shape = tensor_.mutable_tensor_shape();
|
|
||||||
tensor_shape->Clear();
|
|
||||||
size_t element_count = 1;
|
|
||||||
for (auto dim : shape) {
|
|
||||||
if (dim <= 0 || element_count > kMaxShapeElementCount / dim) {
|
|
||||||
MSI_LOG_ERROR << "failed to set shape, invalid dim num " << dim;
|
|
||||||
tensor_shape->Clear();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
element_count *= dim;
|
|
||||||
tensor_shape->add_dims(dim);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool ServingTensor::resize_data(size_t data_len) {
|
|
||||||
string *buffer = tensor_.mutable_data();
|
|
||||||
if (buffer == nullptr) {
|
|
||||||
MSI_LOG_ERROR << "invalid buffer data";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
buffer->resize(data_len);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t ServingTensor::data_size() const { return tensor_.data().size(); }
|
|
||||||
|
|
||||||
void *ServingTensor::mutable_data() { return const_cast<char *>(tensor_.mutable_data()->data()); }
|
|
||||||
|
|
||||||
const void *ServingTensor::data() const { return tensor_.data().data(); }
|
|
||||||
|
|
||||||
ServingRequest::ServingRequest(const ms_serving::PredictRequest &request) : request_(request) {
|
|
||||||
auto &data = request_.data();
|
|
||||||
std::transform(data.begin(), data.end(), std::back_inserter(cache_),
|
|
||||||
[](const ms_serving::Tensor &item) { return ServingTensor(const_cast<ms_serving::Tensor &>(item)); });
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t ServingRequest::size() const { return cache_.size(); }
|
|
||||||
|
|
||||||
const InferTensorBase *ServingRequest::operator[](size_t index) const {
|
|
||||||
if (index >= cache_.size()) {
|
|
||||||
MSI_LOG_ERROR << "visit invalid index " << index << " total size " << cache_.size();
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
return &(cache_[index]);
|
|
||||||
}
|
|
||||||
|
|
||||||
ServingImages::ServingImages(const ms_serving::Images &images) : images_(images) {}
|
|
||||||
|
|
||||||
size_t ServingImages::batch_size() const { return images_.images_size(); }
|
|
||||||
|
|
||||||
bool ServingImages::get(size_t index, const void *&pic_buffer, uint32_t &pic_size) const {
|
|
||||||
if (index >= static_cast<size_t>(images_.images_size())) {
|
|
||||||
MSI_LOG_ERROR << "visit invalid index " << index << " total size " << images_.images_size();
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
pic_buffer = images_.images(index).data();
|
|
||||||
pic_size = images_.images(index).size();
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t ServingImages::input_index() const { return static_cast<size_t>(images_.input_index()); }
|
|
||||||
|
|
||||||
size_t ServingReply::size() const { return cache_.size(); }
|
|
||||||
|
|
||||||
InferTensorBase *ServingReply::operator[](size_t index) {
|
|
||||||
if (index >= cache_.size()) {
|
|
||||||
MSI_LOG_ERROR << "visit invalid index " << index << " total size " << cache_.size();
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
return &(cache_[index]);
|
|
||||||
}
|
|
||||||
|
|
||||||
const InferTensorBase *ServingReply::operator[](size_t index) const {
|
|
||||||
if (index >= cache_.size()) {
|
|
||||||
MSI_LOG_ERROR << "visit invalid index " << index << " total size " << cache_.size();
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
return &(cache_[index]);
|
|
||||||
}
|
|
||||||
|
|
||||||
InferTensorBase *ServingReply::add() {
|
|
||||||
auto new_item = reply_.add_result();
|
|
||||||
if (new_item == nullptr) {
|
|
||||||
MSI_LOG_ERROR << "add new item failed, current total size " << cache_.size();
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
cache_.push_back(ServingTensor(*new_item));
|
|
||||||
return &(cache_.back());
|
|
||||||
}
|
|
||||||
|
|
||||||
void ServingReply::clear() { reply_.mutable_result()->Clear(); }
|
|
||||||
|
|
||||||
ServingImagesRequest::ServingImagesRequest(const ms_serving::PredictRequest &request) : request_(request) {
|
|
||||||
auto &images_inputs = request_.images();
|
|
||||||
std::transform(images_inputs.begin(), images_inputs.end(), std::back_inserter(cache_),
|
|
||||||
[](const ms_serving::Images &item) { return ServingImages(const_cast<ms_serving::Images &>(item)); });
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t ServingImagesRequest::size() const { return cache_.size(); }
|
|
||||||
|
|
||||||
const inference::InferImagesBase *ServingImagesRequest::operator[](size_t index) const {
|
|
||||||
if (index >= cache_.size()) {
|
|
||||||
MSI_LOG_ERROR << "visit invalid index " << index << " total size " << cache_.size();
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
return &(cache_[index]);
|
|
||||||
}
|
|
||||||
} // namespace serving
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,105 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_SERVING_TENSOR_H_
|
|
||||||
#define MINDSPORE_SERVING_TENSOR_H_
|
|
||||||
|
|
||||||
#include <utility>
|
|
||||||
#include <vector>
|
|
||||||
#include <memory>
|
|
||||||
#include "include/infer_tensor.h"
|
|
||||||
#include "serving/ms_service.pb.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace serving {
|
|
||||||
class MS_API ServingTensor : public inference::InferTensorBase {
|
|
||||||
public:
|
|
||||||
// the other's lifetime must longer than this object
|
|
||||||
explicit ServingTensor(ms_serving::Tensor &other);
|
|
||||||
~ServingTensor();
|
|
||||||
|
|
||||||
inference::DataType data_type() const override;
|
|
||||||
void set_data_type(inference::DataType type) override;
|
|
||||||
std::vector<int64_t> shape() const override;
|
|
||||||
void set_shape(const std::vector<int64_t> &shape) override;
|
|
||||||
const void *data() const override;
|
|
||||||
size_t data_size() const override;
|
|
||||||
bool resize_data(size_t data_len) override;
|
|
||||||
void *mutable_data() override;
|
|
||||||
|
|
||||||
private:
|
|
||||||
// if tensor_ is reference from other ms_serving::Tensor, the other's lifetime must
|
|
||||||
// longer than this object
|
|
||||||
ms_serving::Tensor &tensor_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class ServingImages : public inference::InferImagesBase {
|
|
||||||
public:
|
|
||||||
explicit ServingImages(const ms_serving::Images &images);
|
|
||||||
~ServingImages() = default;
|
|
||||||
|
|
||||||
size_t batch_size() const override;
|
|
||||||
bool get(size_t index, const void *&pic_buffer, uint32_t &pic_size) const override;
|
|
||||||
size_t input_index() const override;
|
|
||||||
|
|
||||||
private:
|
|
||||||
const ms_serving::Images &images_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class ServingRequest : public inference::RequestBase {
|
|
||||||
public:
|
|
||||||
explicit ServingRequest(const ms_serving::PredictRequest &request);
|
|
||||||
~ServingRequest() = default;
|
|
||||||
|
|
||||||
size_t size() const override;
|
|
||||||
const inference::InferTensorBase *operator[](size_t index) const override;
|
|
||||||
|
|
||||||
private:
|
|
||||||
const ms_serving::PredictRequest &request_;
|
|
||||||
std::vector<ServingTensor> cache_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class ServingReply : public inference::ReplyBase {
|
|
||||||
public:
|
|
||||||
explicit ServingReply(ms_serving::PredictReply &reply) : reply_(reply) {}
|
|
||||||
~ServingReply() = default;
|
|
||||||
|
|
||||||
size_t size() const override;
|
|
||||||
inference::InferTensorBase *operator[](size_t index) override;
|
|
||||||
const inference::InferTensorBase *operator[](size_t index) const override;
|
|
||||||
inference::InferTensorBase *add() override;
|
|
||||||
void clear() override;
|
|
||||||
|
|
||||||
private:
|
|
||||||
ms_serving::PredictReply &reply_;
|
|
||||||
std::vector<ServingTensor> cache_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class ServingImagesRequest : public inference::ImagesRequestBase {
|
|
||||||
public:
|
|
||||||
explicit ServingImagesRequest(const ms_serving::PredictRequest &request);
|
|
||||||
~ServingImagesRequest() = default;
|
|
||||||
|
|
||||||
size_t size() const override;
|
|
||||||
const inference::InferImagesBase *operator[](size_t index) const override;
|
|
||||||
|
|
||||||
private:
|
|
||||||
const ms_serving::PredictRequest &request_;
|
|
||||||
std::vector<ServingImages> cache_;
|
|
||||||
};
|
|
||||||
} // namespace serving
|
|
||||||
} // namespace mindspore
|
|
||||||
#endif // MINDSPORE_SERVING_TENSOR_H_
|
|
|
@ -1,154 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include "core/session.h"
|
|
||||||
#include <grpcpp/grpcpp.h>
|
|
||||||
#include <string>
|
|
||||||
#include <map>
|
|
||||||
#include <vector>
|
|
||||||
#include <utility>
|
|
||||||
#include <memory>
|
|
||||||
#include <chrono>
|
|
||||||
|
|
||||||
#include "include/infer_log.h"
|
|
||||||
#include "serving/ms_service.grpc.pb.h"
|
|
||||||
#include "core/util/option_parser.h"
|
|
||||||
#include "core/version_control/version_controller.h"
|
|
||||||
#include "core/util/file_system_operation.h"
|
|
||||||
#include "core/serving_tensor.h"
|
|
||||||
|
|
||||||
using ms_serving::MSService;
|
|
||||||
using ms_serving::PredictReply;
|
|
||||||
using ms_serving::PredictRequest;
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace serving {
|
|
||||||
Status Session::CreatDeviceSession(const std::string &device, uint32_t device_id) {
|
|
||||||
session_ = inference::InferSession::CreateSession(device, device_id);
|
|
||||||
if (session_ == nullptr) {
|
|
||||||
MSI_LOG(ERROR) << "Creat Session Failed";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
device_type_ = device;
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
Session &Session::Instance() {
|
|
||||||
static Session instance;
|
|
||||||
return instance;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status Session::Predict(const PredictRequest &request, PredictReply &reply) {
|
|
||||||
try {
|
|
||||||
auto status = PredictInner(request, reply);
|
|
||||||
return status;
|
|
||||||
} catch (const std::bad_alloc &ex) {
|
|
||||||
MSI_LOG(ERROR) << "Serving Error: malloc memory failed";
|
|
||||||
std::cout << "Serving Error: malloc memory failed" << std::endl;
|
|
||||||
} catch (const std::runtime_error &ex) {
|
|
||||||
MSI_LOG(ERROR) << "Serving Error: runtime error occurred: " << ex.what();
|
|
||||||
std::cout << "Serving Error: runtime error occurred: " << ex.what() << std::endl;
|
|
||||||
} catch (const std::exception &ex) {
|
|
||||||
MSI_LOG(ERROR) << "Serving Error: exception occurred: " << ex.what();
|
|
||||||
std::cout << "Serving Error: exception occurred: " << ex.what() << std::endl;
|
|
||||||
} catch (...) {
|
|
||||||
MSI_LOG(ERROR) << "Serving Error: exception occurred";
|
|
||||||
std::cout << "Serving Error: exception occurred";
|
|
||||||
}
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status Session::PredictInner(const PredictRequest &request, PredictReply &reply) {
|
|
||||||
if (!model_loaded_) {
|
|
||||||
MSI_LOG(ERROR) << "the model has not loaded";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
if (session_ == nullptr) {
|
|
||||||
MSI_LOG(ERROR) << "the inference session has not be initialized";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
|
||||||
MSI_LOG(INFO) << "run Predict";
|
|
||||||
|
|
||||||
if (request.images_size() > 0) {
|
|
||||||
ServingImagesRequest serving_images(request);
|
|
||||||
ServingRequest serving_request(request);
|
|
||||||
ServingReply serving_reply(reply);
|
|
||||||
Status ret = session_->ExecuteModel(graph_id_, serving_images, serving_request, serving_reply);
|
|
||||||
if (ret != SUCCESS) {
|
|
||||||
MSI_LOG(ERROR) << "execute model with images return failed";
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
} else if (request.data_size() > 0) {
|
|
||||||
ServingRequest serving_request(request);
|
|
||||||
ServingReply serving_reply(reply);
|
|
||||||
Status ret = session_->ExecuteModel(graph_id_, serving_request, serving_reply);
|
|
||||||
if (ret != SUCCESS) {
|
|
||||||
MSI_LOG(ERROR) << "execute model with datas return failed";
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
MSI_LOG(INFO) << "run Predict finished";
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status Session::Warmup(const MindSporeModelPtr model) {
|
|
||||||
if (session_ == nullptr) {
|
|
||||||
MSI_LOG(ERROR) << "The CreatDeviceSession should be called, before warmup";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
|
||||||
std::string file_name = model->GetModelPath() + '/' + model->GetModelName();
|
|
||||||
model_loaded_ = false;
|
|
||||||
MSI_TIME_STAMP_START(LoadModelFromFile)
|
|
||||||
auto ret = session_->LoadModelFromFile(file_name, graph_id_);
|
|
||||||
MSI_TIME_STAMP_END(LoadModelFromFile)
|
|
||||||
if (ret != SUCCESS) {
|
|
||||||
MSI_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
model_loaded_ = true;
|
|
||||||
MSI_LOG(INFO) << "Session Warmup finished";
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status Session::Clear() {
|
|
||||||
if (session_ != nullptr) {
|
|
||||||
session_->UnloadModel(graph_id_);
|
|
||||||
session_->FinalizeEnv();
|
|
||||||
session_ = nullptr;
|
|
||||||
}
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status Session::GetModelInputsInfo(std::vector<inference::InferTensor> &tensor_list) {
|
|
||||||
if (!model_loaded_) {
|
|
||||||
MSI_LOG(ERROR) << "the model has not loaded";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
if (session_ == nullptr) {
|
|
||||||
MSI_LOG(ERROR) << "the inference session has not be initialized";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
|
||||||
Status ret = session_->GetModelInputsInfo(graph_id_, &tensor_list);
|
|
||||||
if (ret != SUCCESS) {
|
|
||||||
MSI_LOG(ERROR) << "get model inputs info failed";
|
|
||||||
}
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
} // namespace serving
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,61 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#ifndef MINDSPORE_SERVING_SESSION_H
|
|
||||||
#define MINDSPORE_SERVING_SESSION_H
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
#include <mutex>
|
|
||||||
#include <vector>
|
|
||||||
#include <memory>
|
|
||||||
#include "util/status.h"
|
|
||||||
#include "version_control/model.h"
|
|
||||||
#include "include/inference.h"
|
|
||||||
#include "serving/ms_service.pb.h"
|
|
||||||
#include "serving/ms_service.grpc.pb.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace serving {
|
|
||||||
using inference::FAILED;
|
|
||||||
using inference::INVALID_INPUTS;
|
|
||||||
using inference::Status;
|
|
||||||
using inference::SUCCESS;
|
|
||||||
using ms_serving::PredictReply;
|
|
||||||
using ms_serving::PredictRequest;
|
|
||||||
|
|
||||||
class Session {
|
|
||||||
public:
|
|
||||||
static Session &Instance();
|
|
||||||
Status CreatDeviceSession(const std::string &device, uint32_t device_id);
|
|
||||||
Status Predict(const PredictRequest &request, PredictReply &reply);
|
|
||||||
Status Warmup(const MindSporeModelPtr model);
|
|
||||||
Status Clear();
|
|
||||||
Status GetModelInputsInfo(std::vector<inference::InferTensor> &tensor_list);
|
|
||||||
|
|
||||||
private:
|
|
||||||
Session() = default;
|
|
||||||
~Session() = default;
|
|
||||||
int sesseion_id_{0};
|
|
||||||
std::shared_ptr<inference::InferSession> session_{nullptr};
|
|
||||||
bool model_loaded_ = false;
|
|
||||||
uint32_t graph_id_{0};
|
|
||||||
std::mutex mutex_;
|
|
||||||
std::string device_type_;
|
|
||||||
|
|
||||||
Status PredictInner(const PredictRequest &request, PredictReply &reply);
|
|
||||||
};
|
|
||||||
} // namespace serving
|
|
||||||
} // namespace mindspore
|
|
||||||
#endif // MINDSPORE_SERVER_H
|
|
|
@ -1,67 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include "core/util/file_system_operation.h"
|
|
||||||
#include <unistd.h>
|
|
||||||
#include <dirent.h>
|
|
||||||
#include <sys/types.h>
|
|
||||||
#include <sys/stat.h>
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
#include <iostream>
|
|
||||||
#include <algorithm>
|
|
||||||
#include <ctime>
|
|
||||||
#include <fstream>
|
|
||||||
#include <memory>
|
|
||||||
#include "include/infer_log.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace serving {
|
|
||||||
bool DirOrFileExist(const std::string &file_path) {
|
|
||||||
int ret = access(file_path.c_str(), 0);
|
|
||||||
return (ret == -1) ? false : true;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<std::string> GetAllSubDirs(const std::string &dir_path) {
|
|
||||||
DIR *dir = nullptr;
|
|
||||||
struct dirent *ptr = nullptr;
|
|
||||||
std::vector<std::string> SubDirs;
|
|
||||||
|
|
||||||
if ((dir = opendir(dir_path.c_str())) == NULL) {
|
|
||||||
MSI_LOG(ERROR) << "Open " << dir_path << " error!";
|
|
||||||
return std::vector<std::string>();
|
|
||||||
}
|
|
||||||
|
|
||||||
while ((ptr = readdir(dir)) != NULL) {
|
|
||||||
std::string name = ptr->d_name;
|
|
||||||
if (name == "." || name == "..") {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (ptr->d_type == DT_DIR) {
|
|
||||||
SubDirs.push_back(dir_path + "/" + name);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
closedir(dir);
|
|
||||||
std::sort(SubDirs.begin(), SubDirs.end());
|
|
||||||
return SubDirs;
|
|
||||||
}
|
|
||||||
|
|
||||||
time_t GetModifyTime(const std::string &file_path) {
|
|
||||||
struct stat info;
|
|
||||||
(void)stat(file_path.c_str(), &info);
|
|
||||||
return info.st_mtime;
|
|
||||||
}
|
|
||||||
} // namespace serving
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,32 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#ifndef MINDSPORE_SERVING_FILE_SYSTEM_OPERATION_H_
|
|
||||||
#define MINDSPORE_SERVING_FILE_SYSTEM_OPERATION_H_
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
#include <ctime>
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace serving {
|
|
||||||
char *ReadFile(const char *file, size_t *size);
|
|
||||||
bool DirOrFileExist(const std::string &file_path);
|
|
||||||
std::vector<std::string> GetAllSubDirs(const std::string &dir_path);
|
|
||||||
time_t GetModifyTime(const std::string &file_path);
|
|
||||||
} // namespace serving
|
|
||||||
} // namespace mindspore
|
|
||||||
|
|
||||||
#endif // !MINDSPORE_SERVING_FILE_SYSTEM_OPERATION_H_
|
|
|
@ -1,259 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include "core/util/option_parser.h"
|
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
|
||||||
#include <cstring>
|
|
||||||
#include <iostream>
|
|
||||||
#include <iomanip>
|
|
||||||
#include "include/infer_log.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace serving {
|
|
||||||
bool StartWith(const std::string &str, const std::string &expected) {
|
|
||||||
return expected.empty() ||
|
|
||||||
(str.size() >= expected.size() && memcmp(str.data(), expected.data(), expected.size()) == 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool RemovePrefix(std::string *const str, const std::string &prefix) {
|
|
||||||
if (!StartWith(*str, prefix)) return false;
|
|
||||||
str->replace(str->begin(), str->begin() + prefix.size(), "");
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Option::ParseInt32(std::string *const arg) {
|
|
||||||
if (RemovePrefix(arg, "--") && RemovePrefix(arg, name_) && RemovePrefix(arg, "=")) {
|
|
||||||
int32_t parsed_value;
|
|
||||||
try {
|
|
||||||
parsed_value = std::stoi(arg->data());
|
|
||||||
} catch (std::invalid_argument) {
|
|
||||||
std::cout << "Parse " << name_ << " Error for option " << *arg << std::endl;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
*int32_default_ = parsed_value;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Option::ParseBool(std::string *const arg) {
|
|
||||||
if (RemovePrefix(arg, "--") && RemovePrefix(arg, name_) && RemovePrefix(arg, "=")) {
|
|
||||||
if (*arg == "true") {
|
|
||||||
*bool_default_ = true;
|
|
||||||
} else if (*arg == "false") {
|
|
||||||
*bool_default_ = false;
|
|
||||||
} else {
|
|
||||||
std::cout << "Parse " << name_ << " Error for option " << *arg << std::endl;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Option::ParseString(std::string *const arg) {
|
|
||||||
if (RemovePrefix(arg, "--") && RemovePrefix(arg, name_) && RemovePrefix(arg, "=")) {
|
|
||||||
*string_default_ = *arg;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Option::ParseFloat(std::string *const arg) {
|
|
||||||
if (RemovePrefix(arg, "--") && RemovePrefix(arg, name_) && RemovePrefix(arg, "=")) {
|
|
||||||
float parsed_value;
|
|
||||||
try {
|
|
||||||
parsed_value = std::stof(arg->data());
|
|
||||||
} catch (std::invalid_argument) {
|
|
||||||
std::cout << "Parse " << name_ << " Error for option " << *arg << std::endl;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
*float_default_ = parsed_value;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
Option::Option(const std::string &name, int32_t *const default_point, const std::string &usage)
|
|
||||||
: name_(name),
|
|
||||||
type_(MS_TYPE_INT32),
|
|
||||||
int32_default_(default_point),
|
|
||||||
bool_default_(nullptr),
|
|
||||||
string_default_(nullptr),
|
|
||||||
float_default_(nullptr),
|
|
||||||
usage_(usage) {}
|
|
||||||
|
|
||||||
Option::Option(const std::string &name, bool *const default_point, const std::string &usage)
|
|
||||||
: name_(name),
|
|
||||||
type_(MS_TYPE_BOOL),
|
|
||||||
int32_default_(nullptr),
|
|
||||||
bool_default_(default_point),
|
|
||||||
string_default_(nullptr),
|
|
||||||
float_default_(nullptr),
|
|
||||||
usage_(usage) {}
|
|
||||||
|
|
||||||
Option::Option(const std::string &name, std::string *const default_point, const std::string &usage)
|
|
||||||
: name_(name),
|
|
||||||
type_(MS_TYPE_STRING),
|
|
||||||
int32_default_(nullptr),
|
|
||||||
bool_default_(nullptr),
|
|
||||||
string_default_(default_point),
|
|
||||||
float_default_(nullptr),
|
|
||||||
usage_(usage) {}
|
|
||||||
|
|
||||||
Option::Option(const std::string &name, float *const default_point, const std::string &usage)
|
|
||||||
: name_(name),
|
|
||||||
type_(MS_TYPE_FLOAT),
|
|
||||||
int32_default_(nullptr),
|
|
||||||
bool_default_(nullptr),
|
|
||||||
string_default_(nullptr),
|
|
||||||
float_default_(default_point),
|
|
||||||
usage_(usage) {}
|
|
||||||
|
|
||||||
bool Option::Parse(std::string *const arg) {
|
|
||||||
bool result = false;
|
|
||||||
switch (type_) {
|
|
||||||
case MS_TYPE_BOOL:
|
|
||||||
result = ParseBool(arg);
|
|
||||||
break;
|
|
||||||
case MS_TYPE_FLOAT:
|
|
||||||
result = ParseFloat(arg);
|
|
||||||
break;
|
|
||||||
case MS_TYPE_INT32:
|
|
||||||
result = ParseInt32(arg);
|
|
||||||
break;
|
|
||||||
case MS_TYPE_STRING:
|
|
||||||
result = ParseString(arg);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::shared_ptr<Options> Options::inst_ = nullptr;
|
|
||||||
|
|
||||||
Options &Options::Instance() {
|
|
||||||
static Options instance;
|
|
||||||
return instance;
|
|
||||||
}
|
|
||||||
|
|
||||||
Options::Options() : args_(nullptr) { CreateOptions(); }
|
|
||||||
|
|
||||||
void Options::CreateOptions() {
|
|
||||||
args_ = std::make_shared<Arguments>();
|
|
||||||
std::vector<Option> options = {
|
|
||||||
Option("port", &args_->grpc_port,
|
|
||||||
"[Optional] Port to listen on for gRPC API, default is 5500, range from 1 to 65535"),
|
|
||||||
Option("rest_api_port", &args_->rest_api_port,
|
|
||||||
"[Optional] Port to listen on for RESTful API, default is 5501, range from 1 to 65535"),
|
|
||||||
Option("model_name", &args_->model_name, "[Required] model name "),
|
|
||||||
Option("model_path", &args_->model_path, "[Required] the path of the model files"),
|
|
||||||
Option("device_id", &args_->device_id, "[Optional] the device id, default is 0, range from 0 to 7"),
|
|
||||||
};
|
|
||||||
options_ = options;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Options::CheckOptions() {
|
|
||||||
if (args_->model_name == "" || args_->model_path == "") {
|
|
||||||
std::cout << "Serving Error: model_path and model_name should not be null" << std::endl;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (args_->device_type != "Ascend") {
|
|
||||||
std::cout << "Serving Error: device_type only support Ascend right now" << std::endl;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (args_->device_id > 7) {
|
|
||||||
std::cout << "Serving Error: the device_id should be in [0~7]" << std::endl;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (args_->grpc_port < 1 || args_->grpc_port > 65535) {
|
|
||||||
std::cout << "Serving Error: the port should be in [1~65535]" << std::endl;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (args_->rest_api_port < 1 || args_->rest_api_port > 65535) {
|
|
||||||
std::cout << "Serving Error: the rest_api_port should be in [1~65535]" << std::endl;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (args_->rest_api_port == args_->grpc_port) {
|
|
||||||
std::cout << "Serving Error: the rest_api_port and grpc port should not be same" << std::endl;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Options::ParseCommandLine(int argc, char **argv) {
|
|
||||||
if (argc < 2 || (strcmp(argv[1], "--help") == 0)) {
|
|
||||||
Usage();
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
std::vector<std::string> unkown_options;
|
|
||||||
for (int i = 1; i < argc; ++i) {
|
|
||||||
bool found = false;
|
|
||||||
for (auto &option : options_) {
|
|
||||||
std::string arg = argv[i];
|
|
||||||
if (option.Parse(&arg)) {
|
|
||||||
found = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (found == false) {
|
|
||||||
unkown_options.push_back(argv[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!unkown_options.empty()) {
|
|
||||||
std::cout << "unkown options:" << std::endl;
|
|
||||||
for (const auto &option : unkown_options) {
|
|
||||||
std::cout << option << std::endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
bool valid = (unkown_options.empty() && CheckOptions());
|
|
||||||
if (!valid) {
|
|
||||||
Usage();
|
|
||||||
}
|
|
||||||
return valid;
|
|
||||||
}
|
|
||||||
|
|
||||||
void Options::Usage() {
|
|
||||||
std::cout << "USAGE: mindspore-serving [options]" << std::endl;
|
|
||||||
|
|
||||||
for (const auto &option : options_) {
|
|
||||||
std::string type;
|
|
||||||
switch (option.type_) {
|
|
||||||
case Option::MS_TYPE_BOOL:
|
|
||||||
type = "bool";
|
|
||||||
break;
|
|
||||||
case Option::MS_TYPE_FLOAT:
|
|
||||||
type = "float";
|
|
||||||
break;
|
|
||||||
case Option::MS_TYPE_INT32:
|
|
||||||
type = "int32";
|
|
||||||
break;
|
|
||||||
case Option::MS_TYPE_STRING:
|
|
||||||
type = "string";
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
std::cout << "--" << std::setw(30) << std::left << option.name_ << std::setw(10) << std::left << type
|
|
||||||
<< option.usage_ << std::endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace serving
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,84 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#ifndef MINDSPORE_SERVING_OPTION_PARSER_H_
|
|
||||||
#define MINDSPORE_SERVING_OPTION_PARSER_H_
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace serving {
|
|
||||||
struct Arguments {
|
|
||||||
int32_t grpc_port = 5500;
|
|
||||||
int32_t rest_api_port = 5501;
|
|
||||||
std::string grpc_socket_path;
|
|
||||||
std::string ssl_config_file;
|
|
||||||
int32_t poll_model_wait_seconds = 1;
|
|
||||||
std::string model_name;
|
|
||||||
std::string model_path;
|
|
||||||
std::string device_type = "Ascend";
|
|
||||||
int32_t device_id = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
class Option {
|
|
||||||
public:
|
|
||||||
Option(const std::string &name, int32_t *default_point, const std::string &usage);
|
|
||||||
Option(const std::string &name, bool *default_point, const std::string &usage);
|
|
||||||
Option(const std::string &name, std::string *default_point, const std::string &usage);
|
|
||||||
Option(const std::string &name, float *default_point, const std::string &usage);
|
|
||||||
~Option() = default;
|
|
||||||
|
|
||||||
private:
|
|
||||||
friend class Options;
|
|
||||||
|
|
||||||
bool ParseInt32(std::string *arg);
|
|
||||||
bool ParseBool(std::string *arg);
|
|
||||||
bool ParseString(std::string *arg);
|
|
||||||
bool ParseFloat(std::string *arg);
|
|
||||||
bool Parse(std::string *arg);
|
|
||||||
std::string name_;
|
|
||||||
enum { MS_TYPE_INT32, MS_TYPE_BOOL, MS_TYPE_STRING, MS_TYPE_FLOAT } type_;
|
|
||||||
int32_t *int32_default_;
|
|
||||||
bool *bool_default_;
|
|
||||||
std::string *string_default_;
|
|
||||||
float *float_default_;
|
|
||||||
std::string usage_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class Options {
|
|
||||||
public:
|
|
||||||
~Options() = default;
|
|
||||||
Options(const Options &) = delete;
|
|
||||||
Options &operator=(const Options &) = delete;
|
|
||||||
static Options &Instance();
|
|
||||||
bool ParseCommandLine(int argc, char **argv);
|
|
||||||
void Usage();
|
|
||||||
std::shared_ptr<Arguments> GetArgs() { return args_; }
|
|
||||||
|
|
||||||
private:
|
|
||||||
Options();
|
|
||||||
void CreateOptions();
|
|
||||||
bool CheckOptions();
|
|
||||||
static std::shared_ptr<Options> inst_;
|
|
||||||
std::string usage_;
|
|
||||||
std::vector<Option> options_;
|
|
||||||
std::shared_ptr<Arguments> args_;
|
|
||||||
};
|
|
||||||
} // namespace serving
|
|
||||||
} // namespace mindspore
|
|
||||||
|
|
||||||
#endif
|
|
|
@ -1,29 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#ifndef MINDSPORE_STATUS_H
|
|
||||||
#define MINDSPORE_STATUS_H
|
|
||||||
#include "include/inference.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace serving {
|
|
||||||
using inference::FAILED;
|
|
||||||
using inference::INVALID_INPUTS;
|
|
||||||
using inference::Status;
|
|
||||||
using inference::SUCCESS;
|
|
||||||
} // namespace serving
|
|
||||||
} // namespace mindspore
|
|
||||||
|
|
||||||
#endif // MINDSPORE_STATUS_H
|
|
|
@ -1,32 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include "core/version_control/model.h"
|
|
||||||
#include <string>
|
|
||||||
#include "include/infer_log.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace serving {
|
|
||||||
MindSporeModel::MindSporeModel(const std::string &model_name, const std::string &model_path,
|
|
||||||
const std::string &model_version, const time_t &last_update_time)
|
|
||||||
: model_name_(model_name),
|
|
||||||
model_path_(model_path),
|
|
||||||
model_version_(model_version),
|
|
||||||
last_update_time_(last_update_time) {
|
|
||||||
MSI_LOG(INFO) << "init mindspore model, model_name = " << model_name_ << ", model_path = " << model_path_
|
|
||||||
<< ", model_version = " << model_version_ << ", last_update_time = " << last_update_time_;
|
|
||||||
}
|
|
||||||
} // namespace serving
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,47 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#ifndef MINDSPORE_SERVING_MODEL_H_
|
|
||||||
#define MINDSPORE_SERVING_MODEL_H_
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
#include <ctime>
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace serving {
|
|
||||||
class MindSporeModel {
|
|
||||||
public:
|
|
||||||
MindSporeModel(const std::string &model_name, const std::string &model_path, const std::string &model_version,
|
|
||||||
const time_t &last_update_time);
|
|
||||||
~MindSporeModel() = default;
|
|
||||||
std::string GetModelName() { return model_name_; }
|
|
||||||
std::string GetModelPath() { return model_path_; }
|
|
||||||
std::string GetModelVersion() { return model_version_; }
|
|
||||||
time_t GetLastUpdateTime() { return last_update_time_; }
|
|
||||||
void SetLastUpdateTime(const time_t &last_update_time) { last_update_time_ = last_update_time; }
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::string model_name_;
|
|
||||||
std::string model_path_;
|
|
||||||
std::string model_version_;
|
|
||||||
time_t last_update_time_;
|
|
||||||
};
|
|
||||||
|
|
||||||
using MindSporeModelPtr = std::shared_ptr<MindSporeModel>;
|
|
||||||
} // namespace serving
|
|
||||||
} // namespace mindspore
|
|
||||||
|
|
||||||
#endif // !MINDSPORE_SERVING_MODEL_H_
|
|
|
@ -1,130 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include "core/version_control/version_controller.h"
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
#include <iostream>
|
|
||||||
#include <ctime>
|
|
||||||
#include <memory>
|
|
||||||
#include "util/file_system_operation.h"
|
|
||||||
#include "include/infer_log.h"
|
|
||||||
#include "core/session.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace serving {
|
|
||||||
volatile bool stop_poll = false;
|
|
||||||
|
|
||||||
std::string GetVersionFromPath(const std::string &path) {
|
|
||||||
std::string new_path = path;
|
|
||||||
while (!new_path.empty() && new_path.back() == '/') {
|
|
||||||
new_path = new_path.substr(0, new_path.size() - 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string::size_type index = new_path.find_last_of("/");
|
|
||||||
std::string version = new_path.substr(index + 1);
|
|
||||||
return version;
|
|
||||||
}
|
|
||||||
|
|
||||||
void PeriodicFunction::operator()() {
|
|
||||||
while (true) {
|
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(poll_model_wait_seconds_ * 1000));
|
|
||||||
std::vector<std::string> SubDirs = GetAllSubDirs(models_path_);
|
|
||||||
|
|
||||||
if (version_control_strategy_ == VersionController::VersionControllerStrategy::kLastest) {
|
|
||||||
auto path = SubDirs.empty() ? models_path_ : SubDirs.back();
|
|
||||||
std::string model_version = GetVersionFromPath(path);
|
|
||||||
time_t last_update_time = GetModifyTime(path);
|
|
||||||
if (model_version != valid_models_.back()->GetModelVersion()) {
|
|
||||||
MindSporeModelPtr model_ptr = std::make_shared<MindSporeModel>(valid_models_.front()->GetModelName(), path,
|
|
||||||
model_version, last_update_time);
|
|
||||||
valid_models_.back() = model_ptr;
|
|
||||||
Session::Instance().Warmup(valid_models_.back());
|
|
||||||
} else {
|
|
||||||
if (difftime(valid_models_.back()->GetLastUpdateTime(), last_update_time) < 0) {
|
|
||||||
valid_models_.back()->SetLastUpdateTime(last_update_time);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// not support
|
|
||||||
}
|
|
||||||
|
|
||||||
if (stop_poll == true) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
VersionController::VersionController(int32_t poll_model_wait_seconds, const std::string &models_path,
|
|
||||||
const std::string &model_name)
|
|
||||||
: version_control_strategy_(kLastest),
|
|
||||||
poll_model_wait_seconds_(poll_model_wait_seconds),
|
|
||||||
models_path_(models_path),
|
|
||||||
model_name_(model_name) {}
|
|
||||||
|
|
||||||
void StopPollModelPeriodic() { stop_poll = true; }
|
|
||||||
|
|
||||||
VersionController::~VersionController() {
|
|
||||||
StopPollModelPeriodic();
|
|
||||||
if (poll_model_thread_.joinable()) {
|
|
||||||
poll_model_thread_.join();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Status VersionController::Run() {
|
|
||||||
Status ret;
|
|
||||||
ret = CreateInitModels();
|
|
||||||
if (ret != SUCCESS) {
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
return SUCCESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status VersionController::CreateInitModels() {
|
|
||||||
if (!DirOrFileExist(models_path_)) {
|
|
||||||
MSI_LOG(ERROR) << "Model Path Not Exist!" << std::endl;
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
std::vector<std::string> SubDirs = GetAllSubDirs(models_path_);
|
|
||||||
if (version_control_strategy_ == kLastest) {
|
|
||||||
std::string model_version = GetVersionFromPath(models_path_);
|
|
||||||
time_t last_update_time = GetModifyTime(models_path_);
|
|
||||||
MindSporeModelPtr model_ptr =
|
|
||||||
std::make_shared<MindSporeModel>(model_name_, models_path_, model_version, last_update_time);
|
|
||||||
valid_models_.emplace_back(model_ptr);
|
|
||||||
} else {
|
|
||||||
for (auto &dir : SubDirs) {
|
|
||||||
std::string model_version = GetVersionFromPath(dir);
|
|
||||||
time_t last_update_time = GetModifyTime(dir);
|
|
||||||
MindSporeModelPtr model_ptr = std::make_shared<MindSporeModel>(model_name_, dir, model_version, last_update_time);
|
|
||||||
valid_models_.emplace_back(model_ptr);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (valid_models_.empty()) {
|
|
||||||
MSI_LOG(ERROR) << "There is no valid model for serving";
|
|
||||||
return FAILED;
|
|
||||||
}
|
|
||||||
auto ret = Session::Instance().Warmup(valid_models_.back());
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
void VersionController::StartPollModelPeriodic() {
|
|
||||||
poll_model_thread_ = std::thread(
|
|
||||||
PeriodicFunction(poll_model_wait_seconds_, models_path_, version_control_strategy_, std::ref(valid_models_)));
|
|
||||||
}
|
|
||||||
|
|
||||||
void VersionController::StopPollModelPeriodic() {}
|
|
||||||
} // namespace serving
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,70 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#ifndef MINDSPORE_SERVING_VERSOIN_CONTROLLER_H_
|
|
||||||
#define MINDSPORE_SERVING_VERSOIN_CONTROLLER_H_
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
#include <thread>
|
|
||||||
#include "./model.h"
|
|
||||||
#include "util/status.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace serving {
|
|
||||||
class VersionController {
|
|
||||||
public:
|
|
||||||
enum VersionControllerStrategy { kLastest = 0, kMulti = 1 };
|
|
||||||
|
|
||||||
VersionController(int32_t poll_model_wait_seconds, const std::string &models_path, const std::string &model_name);
|
|
||||||
~VersionController();
|
|
||||||
Status Run();
|
|
||||||
void StartPollModelPeriodic();
|
|
||||||
void StopPollModelPeriodic();
|
|
||||||
|
|
||||||
private:
|
|
||||||
Status CreateInitModels();
|
|
||||||
|
|
||||||
private:
|
|
||||||
VersionControllerStrategy version_control_strategy_;
|
|
||||||
std::vector<MindSporeModelPtr> valid_models_;
|
|
||||||
int32_t poll_model_wait_seconds_;
|
|
||||||
std::thread poll_model_thread_;
|
|
||||||
std::string models_path_;
|
|
||||||
std::string model_name_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class PeriodicFunction {
|
|
||||||
public:
|
|
||||||
PeriodicFunction(int32_t poll_model_wait_seconds, const std::string &models_path,
|
|
||||||
VersionController::VersionControllerStrategy version_control_strategy,
|
|
||||||
const std::vector<MindSporeModelPtr> &valid_models)
|
|
||||||
: poll_model_wait_seconds_(poll_model_wait_seconds),
|
|
||||||
models_path_(models_path),
|
|
||||||
version_control_strategy_(version_control_strategy),
|
|
||||||
valid_models_(valid_models) {}
|
|
||||||
~PeriodicFunction() = default;
|
|
||||||
void operator()();
|
|
||||||
|
|
||||||
private:
|
|
||||||
int32_t poll_model_wait_seconds_;
|
|
||||||
std::string models_path_;
|
|
||||||
VersionController::VersionControllerStrategy version_control_strategy_;
|
|
||||||
std::vector<MindSporeModelPtr> valid_models_;
|
|
||||||
};
|
|
||||||
} // namespace serving
|
|
||||||
} // namespace mindspore
|
|
||||||
|
|
||||||
#endif // !MINDSPORE_SERVING_VERSOIN_CONTROLLER_H_
|
|
|
@ -1,80 +0,0 @@
|
||||||
cmake_minimum_required(VERSION 3.5.1)
|
|
||||||
|
|
||||||
project(MSClient C CXX)
|
|
||||||
|
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
|
|
||||||
|
|
||||||
find_package(Threads REQUIRED)
|
|
||||||
|
|
||||||
# This branch assumes that gRPC and all its dependencies are already installed
|
|
||||||
# on this system, so they can be located by find_package().
|
|
||||||
|
|
||||||
# Find Protobuf installation
|
|
||||||
# Looks for protobuf-config.cmake file installed by Protobuf's cmake installation.
|
|
||||||
option(GRPC_PATH "set grpc path")
|
|
||||||
if(GRPC_PATH)
|
|
||||||
set(CMAKE_PREFIX_PATH ${GRPC_PATH})
|
|
||||||
set(protobuf_MODULE_COMPATIBLE TRUE)
|
|
||||||
find_package(Protobuf CONFIG REQUIRED)
|
|
||||||
message(STATUS "Using protobuf ${protobuf_VERSION}, CMAKE_PREFIX_PATH : ${CMAKE_PREFIX_PATH}")
|
|
||||||
elseif(NOT GRPC_PATH AND grpc_ROOT)
|
|
||||||
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
|
|
||||||
if (EXISTS ${grpc_ROOT}/lib64)
|
|
||||||
set(gRPC_DIR "${grpc_ROOT}/lib64/cmake/grpc")
|
|
||||||
elseif(EXISTS ${grpc_ROOT}/lib)
|
|
||||||
set(gRPC_DIR "${grpc_ROOT}/lib/cmake/grpc")
|
|
||||||
endif()
|
|
||||||
add_library(protobuf::libprotobuf ALIAS protobuf::protobuf)
|
|
||||||
add_executable(protobuf::libprotoc ALIAS protobuf::protoc)
|
|
||||||
message(STATUS "serving using grpc_DIR : " ${gRPC_DIR})
|
|
||||||
elseif(NOT gRPC_DIR AND NOT GRPC_PATH)
|
|
||||||
message(FATAL_ERROR "please check gRPC. If the client is compiled separately,you can use the command: cmake -D GRPC_PATH=xxx\n" "XXX is the gRPC installation path")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if(CMAKE_CROSSCOMPILING)
|
|
||||||
find_program(_PROTOBUF_PROTOC protoc)
|
|
||||||
else()
|
|
||||||
set(_PROTOBUF_PROTOC $<TARGET_FILE:protobuf::protoc>)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# Find gRPC installation
|
|
||||||
# Looks for gRPCConfig.cmake file installed by gRPC's cmake installation.
|
|
||||||
find_package(gRPC CONFIG REQUIRED)
|
|
||||||
message(STATUS "Using gRPC ${gRPC_VERSION}")
|
|
||||||
|
|
||||||
if(CMAKE_CROSSCOMPILING)
|
|
||||||
find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin)
|
|
||||||
else()
|
|
||||||
set(_GRPC_CPP_PLUGIN_EXECUTABLE $<TARGET_FILE:gRPC::grpc_cpp_plugin>)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# Proto file
|
|
||||||
get_filename_component(hw_proto "../../ms_service.proto" ABSOLUTE)
|
|
||||||
get_filename_component(hw_proto_path "${hw_proto}" PATH)
|
|
||||||
|
|
||||||
# Generated sources
|
|
||||||
set(hw_proto_srcs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.pb.cc")
|
|
||||||
set(hw_proto_hdrs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.pb.h")
|
|
||||||
set(hw_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.grpc.pb.cc")
|
|
||||||
set(hw_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.grpc.pb.h")
|
|
||||||
add_custom_command(
|
|
||||||
OUTPUT "${hw_proto_srcs}" "${hw_proto_hdrs}" "${hw_grpc_srcs}" "${hw_grpc_hdrs}"
|
|
||||||
COMMAND ${_PROTOBUF_PROTOC}
|
|
||||||
ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}"
|
|
||||||
--cpp_out "${CMAKE_CURRENT_BINARY_DIR}"
|
|
||||||
-I "${hw_proto_path}"
|
|
||||||
--plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}"
|
|
||||||
"${hw_proto}"
|
|
||||||
DEPENDS "${hw_proto}")
|
|
||||||
|
|
||||||
# Include generated *.pb.h files
|
|
||||||
include_directories("${CMAKE_CURRENT_BINARY_DIR}")
|
|
||||||
|
|
||||||
# Targets greeter_[async_](client|server)
|
|
||||||
add_executable(ms_client "ms_client.cc"
|
|
||||||
${hw_proto_srcs}
|
|
||||||
${hw_grpc_srcs})
|
|
||||||
target_link_libraries(ms_client
|
|
||||||
gRPC::grpc++_reflection
|
|
||||||
gRPC::grpc++
|
|
||||||
protobuf::libprotobuf)
|
|
|
@ -1,113 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include <grpcpp/grpcpp.h>
|
|
||||||
#include <iostream>
|
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
|
||||||
#include <fstream>
|
|
||||||
#include "./ms_service.grpc.pb.h"
|
|
||||||
|
|
||||||
using grpc::Channel;
|
|
||||||
using grpc::ClientContext;
|
|
||||||
using grpc::Status;
|
|
||||||
using ms_serving::MSService;
|
|
||||||
using ms_serving::PredictReply;
|
|
||||||
using ms_serving::PredictRequest;
|
|
||||||
using ms_serving::Tensor;
|
|
||||||
using ms_serving::TensorShape;
|
|
||||||
|
|
||||||
class MSClient {
|
|
||||||
public:
|
|
||||||
explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {}
|
|
||||||
|
|
||||||
~MSClient() = default;
|
|
||||||
|
|
||||||
std::string Predict() {
|
|
||||||
// Data we are sending to the server.
|
|
||||||
PredictRequest request;
|
|
||||||
|
|
||||||
Tensor data;
|
|
||||||
TensorShape shape;
|
|
||||||
shape.add_dims(2);
|
|
||||||
shape.add_dims(2);
|
|
||||||
*data.mutable_tensor_shape() = shape;
|
|
||||||
data.set_tensor_type(ms_serving::MS_FLOAT32);
|
|
||||||
std::vector<float> input_data{1, 2, 3, 4};
|
|
||||||
data.set_data(input_data.data(), input_data.size() * sizeof(float));
|
|
||||||
*request.add_data() = data;
|
|
||||||
*request.add_data() = data;
|
|
||||||
std::cout << "intput tensor size is " << request.data_size() << std::endl;
|
|
||||||
// Container for the data we expect from the server.
|
|
||||||
PredictReply reply;
|
|
||||||
|
|
||||||
// Context for the client. It could be used to convey extra information to
|
|
||||||
// the server and/or tweak certain RPC behaviors.
|
|
||||||
ClientContext context;
|
|
||||||
|
|
||||||
// The actual RPC.
|
|
||||||
Status status = stub_->Predict(&context, request, &reply);
|
|
||||||
std::cout << "Compute [[1, 2], [3, 4]] + [[1, 2], [3, 4]]" << std::endl;
|
|
||||||
|
|
||||||
// Act upon its status.
|
|
||||||
if (status.ok()) {
|
|
||||||
std::cout << "Add result is";
|
|
||||||
for (size_t i = 0; i < reply.result(0).data().size() / sizeof(float); i++) {
|
|
||||||
std::cout << " " << (reinterpret_cast<const float *>(reply.mutable_result(0)->mutable_data()->data()))[i];
|
|
||||||
}
|
|
||||||
std::cout << std::endl;
|
|
||||||
return "RPC OK";
|
|
||||||
} else {
|
|
||||||
std::cout << status.error_code() << ": " << status.error_message() << std::endl;
|
|
||||||
return "RPC failed";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::unique_ptr<MSService::Stub> stub_;
|
|
||||||
};
|
|
||||||
|
|
||||||
int main(int argc, char **argv) {
|
|
||||||
// Instantiate the client. It requires a channel, out of which the actual RPCs
|
|
||||||
// are created. This channel models a connection to an endpoint specified by
|
|
||||||
// the argument "--target=" which is the only expected argument.
|
|
||||||
// We indicate that the channel isn't authenticated (use of
|
|
||||||
// InsecureChannelCredentials()).
|
|
||||||
std::string target_str;
|
|
||||||
std::string arg_target_str("--target");
|
|
||||||
if (argc > 1) {
|
|
||||||
// parse target
|
|
||||||
std::string arg_val = argv[1];
|
|
||||||
size_t start_pos = arg_val.find(arg_target_str);
|
|
||||||
if (start_pos != std::string::npos) {
|
|
||||||
start_pos += arg_target_str.size();
|
|
||||||
if (start_pos < arg_val.size() && arg_val[start_pos] == '=') {
|
|
||||||
target_str = arg_val.substr(start_pos + 1);
|
|
||||||
} else {
|
|
||||||
std::cout << "The only correct argument syntax is --target=" << std::endl;
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
target_str = "localhost:5500";
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
target_str = "localhost:5500";
|
|
||||||
}
|
|
||||||
MSClient client(grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials()));
|
|
||||||
std::string reply = client.Predict();
|
|
||||||
std::cout << "client received: " << reply << std::endl;
|
|
||||||
|
|
||||||
return 0;
|
|
||||||
}
|
|
|
@ -1,45 +0,0 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import mindspore.context as context
|
|
||||||
import mindspore.nn as nn
|
|
||||||
from mindspore.ops import operations as P
|
|
||||||
from mindspore import Tensor
|
|
||||||
from mindspore.train.serialization import export
|
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
|
||||||
|
|
||||||
class Net(nn.Cell):
|
|
||||||
def __init__(self):
|
|
||||||
super(Net, self).__init__()
|
|
||||||
self.add = P.TensorAdd()
|
|
||||||
|
|
||||||
def construct(self, x_, y_):
|
|
||||||
return self.add(x_, y_)
|
|
||||||
|
|
||||||
def export_net():
|
|
||||||
x = np.ones([2, 2]).astype(np.float32)
|
|
||||||
y = np.ones([2, 2]).astype(np.float32)
|
|
||||||
add = Net()
|
|
||||||
output = add(Tensor(x), Tensor(y))
|
|
||||||
export(add, Tensor(x), Tensor(y), file_name='tensor_add', file_format='MINDIR')
|
|
||||||
print(x)
|
|
||||||
print(y)
|
|
||||||
print(output.asnumpy())
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
export_net()
|
|
||||||
|
|
|
@ -1,61 +0,0 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
import sys
|
|
||||||
import grpc
|
|
||||||
import numpy as np
|
|
||||||
import ms_service_pb2
|
|
||||||
import ms_service_pb2_grpc
|
|
||||||
|
|
||||||
|
|
||||||
def run():
|
|
||||||
if len(sys.argv) > 2:
|
|
||||||
sys.exit("input error")
|
|
||||||
channel_str = ""
|
|
||||||
if len(sys.argv) == 2:
|
|
||||||
split_args = sys.argv[1].split('=')
|
|
||||||
if len(split_args) > 1:
|
|
||||||
channel_str = split_args[1]
|
|
||||||
else:
|
|
||||||
channel_str = 'localhost:5500'
|
|
||||||
else:
|
|
||||||
channel_str = 'localhost:5500'
|
|
||||||
|
|
||||||
channel = grpc.insecure_channel(channel_str)
|
|
||||||
stub = ms_service_pb2_grpc.MSServiceStub(channel)
|
|
||||||
request = ms_service_pb2.PredictRequest()
|
|
||||||
|
|
||||||
x = request.data.add()
|
|
||||||
x.tensor_shape.dims.extend([2, 2])
|
|
||||||
x.tensor_type = ms_service_pb2.MS_FLOAT32
|
|
||||||
x.data = (np.ones([2, 2]).astype(np.float32)).tobytes()
|
|
||||||
|
|
||||||
y = request.data.add()
|
|
||||||
y.tensor_shape.dims.extend([2, 2])
|
|
||||||
y.tensor_type = ms_service_pb2.MS_FLOAT32
|
|
||||||
y.data = (np.ones([2, 2]).astype(np.float32)).tobytes()
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = stub.Predict(request)
|
|
||||||
result_np = np.frombuffer(result.result[0].data, dtype=np.float32).reshape(result.result[0].tensor_shape.dims)
|
|
||||||
print("ms client received: ")
|
|
||||||
print(result_np)
|
|
||||||
except grpc.RpcError as e:
|
|
||||||
print(e.details())
|
|
||||||
status_code = e.code()
|
|
||||||
print(status_code.name)
|
|
||||||
print(status_code.value)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
run()
|
|
|
@ -1,29 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include "core/server.h"
|
|
||||||
#include "core/util/option_parser.h"
|
|
||||||
|
|
||||||
using mindspore::serving::Options;
|
|
||||||
|
|
||||||
int main(int argc, char **argv) {
|
|
||||||
auto flag = Options::Instance().ParseCommandLine(argc, argv);
|
|
||||||
if (!flag) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
mindspore::serving::Server server;
|
|
||||||
server.BuildAndStart();
|
|
||||||
return 0;
|
|
||||||
}
|
|
|
@ -1,70 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
// ms_service.proto
|
|
||||||
syntax = "proto3";
|
|
||||||
|
|
||||||
package ms_serving;
|
|
||||||
|
|
||||||
service MSService {
|
|
||||||
rpc Predict(PredictRequest) returns (PredictReply) {}
|
|
||||||
rpc Test(PredictRequest) returns (PredictReply) {}
|
|
||||||
}
|
|
||||||
|
|
||||||
message PredictRequest {
|
|
||||||
repeated Tensor data = 1;
|
|
||||||
repeated Images images = 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
message PredictReply {
|
|
||||||
repeated Tensor result = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
enum DataType {
|
|
||||||
MS_UNKNOWN = 0;
|
|
||||||
MS_BOOL = 1;
|
|
||||||
MS_INT8 = 2;
|
|
||||||
MS_UINT8 = 3;
|
|
||||||
MS_INT16 = 4;
|
|
||||||
MS_UINT16 = 5;
|
|
||||||
MS_INT32 = 6;
|
|
||||||
MS_UINT32 = 7;
|
|
||||||
MS_INT64 = 8;
|
|
||||||
MS_UINT64 = 9;
|
|
||||||
MS_FLOAT16 = 10;
|
|
||||||
MS_FLOAT32 = 11;
|
|
||||||
MS_FLOAT64 = 12;
|
|
||||||
}
|
|
||||||
|
|
||||||
message TensorShape {
|
|
||||||
repeated int64 dims = 1;
|
|
||||||
};
|
|
||||||
|
|
||||||
message Tensor {
|
|
||||||
// tensor shape info
|
|
||||||
TensorShape tensor_shape = 1;
|
|
||||||
|
|
||||||
// tensor content data type
|
|
||||||
DataType tensor_type = 2;
|
|
||||||
|
|
||||||
// tensor data
|
|
||||||
bytes data = 3;
|
|
||||||
}
|
|
||||||
|
|
||||||
message Images{
|
|
||||||
repeated bytes images = 1;
|
|
||||||
uint32 input_index = 2;
|
|
||||||
}
|
|
|
@ -1,105 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
# Copyright 2019 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
set -e
|
|
||||||
|
|
||||||
CLANG_FORMAT=$(which clang-format) || (echo "Please install 'clang-format' tool first"; exit 1)
|
|
||||||
|
|
||||||
version=$("${CLANG_FORMAT}" --version | sed -n "s/.*\ \([0-9]*\)\.[0-9]*\.[0-9]*.*/\1/p")
|
|
||||||
if [[ "${version}" -lt "8" ]]; then
|
|
||||||
echo "clang-format's version must be at least 8.0.0"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
CURRENT_PATH=$(pwd)
|
|
||||||
SCRIPTS_PATH=$(dirname "$0")
|
|
||||||
|
|
||||||
echo "CURRENT_PATH=${CURRENT_PATH}"
|
|
||||||
echo "SCRIPTS_PATH=${SCRIPTS_PATH}"
|
|
||||||
|
|
||||||
# print usage message
|
|
||||||
function usage()
|
|
||||||
{
|
|
||||||
echo "Format the specified source files to conform the code style."
|
|
||||||
echo "Usage:"
|
|
||||||
echo "bash $0 [-a] [-c] [-l] [-h]"
|
|
||||||
echo "e.g. $0 -c"
|
|
||||||
echo ""
|
|
||||||
echo "Options:"
|
|
||||||
echo " -a format of all files"
|
|
||||||
echo " -c format of the files changed compared to last commit, default case"
|
|
||||||
echo " -l format of the files changed in last commit"
|
|
||||||
echo " -h Print usage"
|
|
||||||
}
|
|
||||||
|
|
||||||
# check and set options
|
|
||||||
function checkopts()
|
|
||||||
{
|
|
||||||
# init variable
|
|
||||||
mode="changed" # default format changed files
|
|
||||||
|
|
||||||
# Process the options
|
|
||||||
while getopts 'aclh' opt
|
|
||||||
do
|
|
||||||
case "${opt}" in
|
|
||||||
a)
|
|
||||||
mode="all"
|
|
||||||
;;
|
|
||||||
c)
|
|
||||||
mode="changed"
|
|
||||||
;;
|
|
||||||
l)
|
|
||||||
mode="lastcommit"
|
|
||||||
;;
|
|
||||||
h)
|
|
||||||
usage
|
|
||||||
exit 0
|
|
||||||
;;
|
|
||||||
*)
|
|
||||||
echo "Unknown option ${opt}!"
|
|
||||||
usage
|
|
||||||
exit 1
|
|
||||||
esac
|
|
||||||
done
|
|
||||||
}
|
|
||||||
|
|
||||||
# init variable
|
|
||||||
# check options
|
|
||||||
checkopts "$@"
|
|
||||||
|
|
||||||
# switch to project root path, which contains clang-format config file '.clang-format'
|
|
||||||
cd "${SCRIPTS_PATH}/../.." || exit 1
|
|
||||||
|
|
||||||
FMT_FILE_LIST='__format_files_list__'
|
|
||||||
|
|
||||||
if [[ "X${mode}" == "Xall" ]]; then
|
|
||||||
find ./ -type f -name "*" | grep "\.h$\|\.cc$" > "${FMT_FILE_LIST}" || true
|
|
||||||
elif [[ "X${mode}" == "Xchanged" ]]; then
|
|
||||||
git diff --name-only | grep "\.h$\|\.cc$" > "${FMT_FILE_LIST}" || true
|
|
||||||
else # "X${mode}" == "Xlastcommit"
|
|
||||||
git diff --name-only HEAD~ HEAD | grep "\.h$\|\.cc$" > "${FMT_FILE_LIST}" || true
|
|
||||||
fi
|
|
||||||
|
|
||||||
while read line; do
|
|
||||||
if [ -f "${line}" ]; then
|
|
||||||
${CLANG_FORMAT} -i "${line}"
|
|
||||||
fi
|
|
||||||
done < "${FMT_FILE_LIST}"
|
|
||||||
|
|
||||||
rm "${FMT_FILE_LIST}"
|
|
||||||
cd "${CURRENT_PATH}" || exit 1
|
|
||||||
|
|
||||||
echo "Specified cpp source files have been format successfully."
|
|
3
setup.py
3
setup.py
|
@ -132,7 +132,6 @@ package_data = {
|
||||||
'lib/*.so*',
|
'lib/*.so*',
|
||||||
'lib/*.a',
|
'lib/*.a',
|
||||||
'.commit_id',
|
'.commit_id',
|
||||||
'ms_serving',
|
|
||||||
'config/*'
|
'config/*'
|
||||||
'include/*',
|
'include/*',
|
||||||
'include/*/*',
|
'include/*/*',
|
||||||
|
@ -160,8 +159,6 @@ def update_permissions(path):
|
||||||
for filename in filenames:
|
for filename in filenames:
|
||||||
file_fullpath = os.path.join(dirpath, filename)
|
file_fullpath = os.path.join(dirpath, filename)
|
||||||
os.chmod(file_fullpath, stat.S_IREAD)
|
os.chmod(file_fullpath, stat.S_IREAD)
|
||||||
if filename == "ms_serving":
|
|
||||||
os.chmod(file_fullpath, stat.S_IREAD | stat.S_IEXEC)
|
|
||||||
|
|
||||||
def bin_files():
|
def bin_files():
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,98 +0,0 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
import random
|
|
||||||
import json
|
|
||||||
import grpc
|
|
||||||
import numpy as np
|
|
||||||
import requests
|
|
||||||
import ms_service_pb2
|
|
||||||
import ms_service_pb2_grpc
|
|
||||||
import mindspore.dataset as de
|
|
||||||
from mindspore import Tensor, context
|
|
||||||
from mindspore import log as logger
|
|
||||||
from tests.st.networks.models.bert.src.bert_model import BertModel
|
|
||||||
from .generate_model import bert_net_cfg
|
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
|
||||||
|
|
||||||
random.seed(1)
|
|
||||||
np.random.seed(1)
|
|
||||||
de.config.set_seed(1)
|
|
||||||
|
|
||||||
def test_bert():
|
|
||||||
MAX_MESSAGE_LENGTH = 0x7fffffff
|
|
||||||
input_ids = np.random.randint(0, 1000, size=(2, 32), dtype=np.int32)
|
|
||||||
segment_ids = np.zeros((2, 32), dtype=np.int32)
|
|
||||||
input_mask = np.zeros((2, 32), dtype=np.int32)
|
|
||||||
|
|
||||||
# grpc visit
|
|
||||||
channel = grpc.insecure_channel('localhost:5500', options=[('grpc.max_send_message_length', MAX_MESSAGE_LENGTH),
|
|
||||||
('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH)])
|
|
||||||
stub = ms_service_pb2_grpc.MSServiceStub(channel)
|
|
||||||
request = ms_service_pb2.PredictRequest()
|
|
||||||
|
|
||||||
x = request.data.add()
|
|
||||||
x.tensor_shape.dims.extend([2, 32])
|
|
||||||
x.tensor_type = ms_service_pb2.MS_INT32
|
|
||||||
x.data = input_ids.tobytes()
|
|
||||||
|
|
||||||
y = request.data.add()
|
|
||||||
y.tensor_shape.dims.extend([2, 32])
|
|
||||||
y.tensor_type = ms_service_pb2.MS_INT32
|
|
||||||
y.data = segment_ids.tobytes()
|
|
||||||
|
|
||||||
z = request.data.add()
|
|
||||||
z.tensor_shape.dims.extend([2, 32])
|
|
||||||
z.tensor_type = ms_service_pb2.MS_INT32
|
|
||||||
z.data = input_mask.tobytes()
|
|
||||||
|
|
||||||
result = stub.Predict(request)
|
|
||||||
grpc_result = np.frombuffer(result.result[0].data, dtype=np.float32).reshape(result.result[0].tensor_shape.dims)
|
|
||||||
print("ms grpc client received: ")
|
|
||||||
print(grpc_result)
|
|
||||||
|
|
||||||
# ms result
|
|
||||||
net = BertModel(bert_net_cfg, False)
|
|
||||||
bert_out = net(Tensor(input_ids), Tensor(segment_ids), Tensor(input_mask))
|
|
||||||
print("bert out: ")
|
|
||||||
print(bert_out[0])
|
|
||||||
bert_out_size = len(bert_out)
|
|
||||||
|
|
||||||
# compare grpc result
|
|
||||||
for i in range(bert_out_size):
|
|
||||||
grpc_result = np.frombuffer(result.result[i].data, dtype=np.float32).reshape(result.result[i].tensor_shape.dims)
|
|
||||||
logger.info("i:{}, grpc_result:{}, bert_out:{}".
|
|
||||||
format(i, result.result[i].tensor_shape.dims, bert_out[i].asnumpy().shape))
|
|
||||||
assert np.allclose(bert_out[i].asnumpy(), grpc_result, 0.001, 0.001, equal_nan=True)
|
|
||||||
|
|
||||||
# http visit
|
|
||||||
data = {"tensor": [input_ids.tolist(), segment_ids.tolist(), input_mask.tolist()]}
|
|
||||||
url = "http://127.0.0.1:5501"
|
|
||||||
input_json = json.dumps(data)
|
|
||||||
headers = {'Content-type': 'application/json'}
|
|
||||||
response = requests.post(url, data=input_json, headers=headers)
|
|
||||||
result = response.text
|
|
||||||
result = result.replace('\r', '\\r').replace('\n', '\\n')
|
|
||||||
result_json = json.loads(result, strict=False)
|
|
||||||
http_result = np.array(result_json['tensor'])
|
|
||||||
print("ms http client received: ")
|
|
||||||
print(http_result[0][:200])
|
|
||||||
|
|
||||||
# compare http result
|
|
||||||
for i in range(bert_out_size):
|
|
||||||
logger.info("i:{}, http_result:{}, bert_out:{}".
|
|
||||||
format(i, np.shape(http_result[i]), bert_out[i].asnumpy().shape))
|
|
||||||
assert np.allclose(bert_out[i].asnumpy(), http_result[i], 0.001, 0.001, equal_nan=True)
|
|
|
@ -1,60 +0,0 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
import random
|
|
||||||
import numpy as np
|
|
||||||
import mindspore.common.dtype as mstype
|
|
||||||
import mindspore.dataset as de
|
|
||||||
from mindspore import Tensor, context
|
|
||||||
from mindspore.train.serialization import export
|
|
||||||
from tests.st.networks.models.bert.src.bert_model import BertModel, BertConfig
|
|
||||||
|
|
||||||
bert_net_cfg = BertConfig(
|
|
||||||
batch_size=2,
|
|
||||||
seq_length=32,
|
|
||||||
vocab_size=12,
|
|
||||||
hidden_size=12,
|
|
||||||
num_hidden_layers=12,
|
|
||||||
num_attention_heads=12,
|
|
||||||
intermediate_size=3072,
|
|
||||||
hidden_act="gelu",
|
|
||||||
hidden_dropout_prob=0.1,
|
|
||||||
attention_probs_dropout_prob=0.1,
|
|
||||||
max_position_embeddings=512,
|
|
||||||
type_vocab_size=2,
|
|
||||||
initializer_range=0.02,
|
|
||||||
use_relative_positions=False,
|
|
||||||
input_mask_from_dataset=True,
|
|
||||||
token_type_ids_from_dataset=True,
|
|
||||||
dtype=mstype.float32,
|
|
||||||
compute_type=mstype.float16
|
|
||||||
)
|
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
|
||||||
|
|
||||||
random.seed(1)
|
|
||||||
np.random.seed(1)
|
|
||||||
de.config.set_seed(1)
|
|
||||||
|
|
||||||
def export_bert_model():
|
|
||||||
input_ids = np.random.randint(0, 1000, size=(2, 32), dtype=np.int32)
|
|
||||||
segment_ids = np.zeros((2, 32), dtype=np.int32)
|
|
||||||
input_mask = np.zeros((2, 32), dtype=np.int32)
|
|
||||||
net = BertModel(bert_net_cfg, False)
|
|
||||||
export(net, Tensor(input_ids), Tensor(segment_ids), Tensor(input_mask),
|
|
||||||
file_name='bert', file_format='MINDIR')
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
export_bert_model()
|
|
|
@ -1,112 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
|
|
||||||
export GLOG_v=1
|
|
||||||
export DEVICE_ID=1
|
|
||||||
|
|
||||||
MINDSPORE_INSTALL_PATH=$1
|
|
||||||
ENV_DEVICE_ID=$DEVICE_ID
|
|
||||||
CURRPATH=$(cd $(dirname $0); pwd)
|
|
||||||
CURRUSER=$(whoami)
|
|
||||||
PROJECT_PATH=${CURRPATH}/../../../
|
|
||||||
echo "MINDSPORE_INSTALL_PATH:" ${MINDSPORE_INSTALL_PATH}
|
|
||||||
echo "ENV_DEVICE_ID:" ${ENV_DEVICE_ID}
|
|
||||||
echo "CURRPATH:" ${CURRPATH}
|
|
||||||
echo "CURRUSER:" ${CURRUSER}
|
|
||||||
echo "PROJECT_PATH:" ${PROJECT_PATH}
|
|
||||||
|
|
||||||
MODEL_PATH=${CURRPATH}/model
|
|
||||||
export LD_LIBRARY_PATH=${MINDSPORE_INSTALL_PATH}/lib:/usr/local/python/python375/lib/:${LD_LIBRARY_PATH}
|
|
||||||
export PYTHONPATH=${MINDSPORE_INSTALL_PATH}/:${PYTHONPATH}
|
|
||||||
|
|
||||||
echo "LD_LIBRARY_PATH: " ${LD_LIBRARY_PATH}
|
|
||||||
echo "PYTHONPATH: " ${PYTHONPATH}
|
|
||||||
echo "-------------show MINDSPORE_INSTALL_PATH----------------"
|
|
||||||
ls -l ${MINDSPORE_INSTALL_PATH}
|
|
||||||
echo "------------------show /usr/lib64/----------------------"
|
|
||||||
ls -l /usr/local/python/python375/lib/
|
|
||||||
|
|
||||||
clean_pid()
|
|
||||||
{
|
|
||||||
ps aux | grep 'ms_serving' | grep ${CURRUSER} | grep -v grep | awk '{print $2}' | xargs kill -15
|
|
||||||
if [ $? -ne 0 ]
|
|
||||||
then
|
|
||||||
echo "clean pip failed"
|
|
||||||
fi
|
|
||||||
sleep 6
|
|
||||||
}
|
|
||||||
|
|
||||||
prepare_model()
|
|
||||||
{
|
|
||||||
echo "### begin to generate mode for serving test ###"
|
|
||||||
python3 generate_model.py &> generate_model_serving.log
|
|
||||||
echo "### end to generate mode for serving test ###"
|
|
||||||
result=`ls -l | grep -E '*mindir' | grep -v ".log" | wc -l`
|
|
||||||
if [ ${result} -ne 1 ]
|
|
||||||
then
|
|
||||||
cat generate_model_serving.log
|
|
||||||
echo "### generate model for serving test failed ###" && exit 1
|
|
||||||
clean_pid
|
|
||||||
fi
|
|
||||||
rm -rf model
|
|
||||||
mkdir model
|
|
||||||
mv *.mindir ${CURRPATH}/model
|
|
||||||
cp ${MINDSPORE_INSTALL_PATH}/ms_serving ./
|
|
||||||
}
|
|
||||||
|
|
||||||
start_service()
|
|
||||||
{
|
|
||||||
echo "### start serving service ###"
|
|
||||||
${CURRPATH}/ms_serving --port=$1 --model_path=${MODEL_PATH} --model_name=$2 --device_id=$3 > $2_service.log 2>&1 &
|
|
||||||
if [ $? -ne 0 ]
|
|
||||||
then
|
|
||||||
echo "$2 faile to start."
|
|
||||||
fi
|
|
||||||
|
|
||||||
result=`grep -E -A5 -B5 'MS Serving grpc listening on 0.0.0.0:5500' $2_service.log |
|
|
||||||
grep -E 'MS Serving restful listening on 0.0.0.0:5501'|wc -l`
|
|
||||||
count=0
|
|
||||||
while [[ ${result} -ne 1 && ${count} -lt 150 ]]
|
|
||||||
do
|
|
||||||
sleep 1
|
|
||||||
count=$(($count+1))
|
|
||||||
result=`grep -E -A5 -B5 'MS Serving grpc listening on 0.0.0.0:5500' $2_service.log |
|
|
||||||
grep -E 'MS Serving restful listening on 0.0.0.0:5501'|wc -l`
|
|
||||||
done
|
|
||||||
|
|
||||||
if [ ${count} -eq 150 ]
|
|
||||||
then
|
|
||||||
clean_pid
|
|
||||||
cat $2_service.log
|
|
||||||
echo "start serving service failed!" && exit 1
|
|
||||||
fi
|
|
||||||
echo "### start serving service end ###"
|
|
||||||
}
|
|
||||||
|
|
||||||
pytest_serving()
|
|
||||||
{
|
|
||||||
unset http_proxy https_proxy
|
|
||||||
CLIENT_DEVICE_ID=$((${ENV_DEVICE_ID}+1))
|
|
||||||
export DEVICE_ID=${CLIENT_DEVICE_ID}
|
|
||||||
local test_client_name=$1
|
|
||||||
echo "### $1 client start ###"
|
|
||||||
python3 -m pytest -v -s client_example.py::${test_client_name} > ${test_client_name}_client.log 2>&1
|
|
||||||
if [ $? -ne 0 ]
|
|
||||||
then
|
|
||||||
clean_pid
|
|
||||||
cat ${test_client_name}_client.log
|
|
||||||
echo "client $1 faile to start." && exit 1
|
|
||||||
fi
|
|
||||||
echo "### $1 client end ###"
|
|
||||||
}
|
|
||||||
|
|
||||||
test_bert_model()
|
|
||||||
{
|
|
||||||
start_service 5500 bert.mindir ${ENV_DEVICE_ID}
|
|
||||||
pytest_serving test_bert
|
|
||||||
clean_pid
|
|
||||||
}
|
|
||||||
|
|
||||||
echo "-----serving start-----"
|
|
||||||
rm -rf ms_serving *.log *.mindir *.dat ${CURRPATH}/model ${CURRPATH}/kernel_meta
|
|
||||||
prepare_model
|
|
||||||
test_bert_model
|
|
|
@ -1,39 +0,0 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import pytest
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
@pytest.mark.level0
|
|
||||||
@pytest.mark.platform_arm_ascend_training
|
|
||||||
@pytest.mark.env_single
|
|
||||||
def test_serving():
|
|
||||||
"""test_serving"""
|
|
||||||
sh_path = os.path.split(os.path.realpath(__file__))[0]
|
|
||||||
python_path_folders = []
|
|
||||||
for python_path in sys.path:
|
|
||||||
if os.path.isdir(python_path):
|
|
||||||
python_path_folders += [python_path]
|
|
||||||
folders = []
|
|
||||||
for folder in python_path_folders:
|
|
||||||
folders += [os.path.join(folder, x) for x in os.listdir(folder) \
|
|
||||||
if os.path.isdir(os.path.join(folder, x)) and '/site-packages/mindspore' in os.path.join(folder, x)]
|
|
||||||
ret = os.system(f"sh {sh_path}/serving.sh {folders[0].split('mindspore', 1)[0] + 'mindspore'}")
|
|
||||||
assert np.allclose(ret, 0, 0.0001, 0.0001)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
test_serving()
|
|
|
@ -76,8 +76,6 @@ else()
|
||||||
endif()
|
endif()
|
||||||
endforeach ()
|
endforeach ()
|
||||||
endif()
|
endif()
|
||||||
# serving ut
|
|
||||||
add_subdirectory(serving)
|
|
||||||
|
|
||||||
file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
"../../../mindspore/ccsrc/pybind_api/*.cc"
|
"../../../mindspore/ccsrc/pybind_api/*.cc"
|
||||||
|
@ -151,8 +149,7 @@ add_library(_ut_mindspore_obj OBJECT ${MINDSPORE_SRC_LIST})
|
||||||
add_library(_ut_ut_obj OBJECT ${UT_SRCS})
|
add_library(_ut_ut_obj OBJECT ${UT_SRCS})
|
||||||
add_dependencies(_ut_ut_obj engine-cache-server)
|
add_dependencies(_ut_ut_obj engine-cache-server)
|
||||||
add_executable(ut_tests $<TARGET_OBJECTS:_ut_ut_obj>
|
add_executable(ut_tests $<TARGET_OBJECTS:_ut_ut_obj>
|
||||||
$<TARGET_OBJECTS:_ut_mindspore_obj>
|
$<TARGET_OBJECTS:_ut_mindspore_obj>)
|
||||||
$<TARGET_OBJECTS:_ut_serving_obj>)
|
|
||||||
|
|
||||||
if (ENABLE_GE)
|
if (ENABLE_GE)
|
||||||
if(ENABLE_TRAIN)
|
if(ENABLE_TRAIN)
|
||||||
|
@ -182,14 +179,3 @@ if (USE_GLOG)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
target_link_libraries(ut_tests PRIVATE mindspore mindspore_shared_lib securec graph)
|
target_link_libraries(ut_tests PRIVATE mindspore mindspore_shared_lib securec graph)
|
||||||
|
|
||||||
# link grpc
|
|
||||||
if (EXISTS ${grpc_ROOT}/lib64)
|
|
||||||
set(gRPC_DIR "${grpc_ROOT}/lib64/cmake/grpc")
|
|
||||||
else ()
|
|
||||||
set(gRPC_DIR "${grpc_ROOT}/lib/cmake/grpc")
|
|
||||||
endif ()
|
|
||||||
find_package(gRPC CONFIG REQUIRED)
|
|
||||||
target_link_libraries(ut_tests PRIVATE gRPC::grpc++)
|
|
||||||
target_link_libraries(ut_tests PRIVATE gRPC::grpc++_reflection)
|
|
||||||
target_link_libraries(ut_tests PRIVATE protobuf::libprotobuf)
|
|
||||||
|
|
|
@ -1,89 +0,0 @@
|
||||||
find_package(Threads REQUIRED)
|
|
||||||
|
|
||||||
# This branch assumes that gRPC and all its dependencies are already installed
|
|
||||||
# on this system, so they can be located by find_package().
|
|
||||||
|
|
||||||
# Find Protobuf installation
|
|
||||||
# Looks for protobuf-config.cmake file installed by Protobuf's cmake installation.
|
|
||||||
|
|
||||||
#set(protobuf_MODULE_COMPATIBLE TRUE)
|
|
||||||
#find_package(Protobuf CONFIG REQUIRED)
|
|
||||||
#message(STATUS "Using protobuf ${protobuf_VERSION}")
|
|
||||||
add_library(protobuf::libprotobuf ALIAS protobuf::protobuf)
|
|
||||||
add_executable(protobuf::libprotoc ALIAS protobuf::protoc)
|
|
||||||
|
|
||||||
set(_PROTOBUF_LIBPROTOBUF protobuf::libprotobuf)
|
|
||||||
if (CMAKE_CROSSCOMPILING)
|
|
||||||
find_program(_PROTOBUF_PROTOC protoc)
|
|
||||||
else ()
|
|
||||||
set(_PROTOBUF_PROTOC $<TARGET_FILE:protobuf::protoc>)
|
|
||||||
endif ()
|
|
||||||
|
|
||||||
# Find gRPC installation
|
|
||||||
# Looks for gRPCConfig.cmake file installed by gRPC's cmake installation.
|
|
||||||
if (EXISTS ${grpc_ROOT}/lib64)
|
|
||||||
set(gRPC_DIR "${grpc_ROOT}/lib64/cmake/grpc")
|
|
||||||
else ()
|
|
||||||
set(gRPC_DIR "${grpc_ROOT}/lib/cmake/grpc")
|
|
||||||
endif ()
|
|
||||||
message("serving ut using grpc_DIR : " ${gPRC_DIR})
|
|
||||||
|
|
||||||
find_package(gRPC CONFIG REQUIRED)
|
|
||||||
message(STATUS "Using gRPC ${gRPC_VERSION}")
|
|
||||||
|
|
||||||
set(_GRPC_GRPCPP gRPC::grpc++)
|
|
||||||
set(_REFLECTION gRPC::grpc++_reflection)
|
|
||||||
|
|
||||||
if (CMAKE_CROSSCOMPILING)
|
|
||||||
find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin)
|
|
||||||
find_program(_GRPC_PYTHON_PLUGIN_EXECUTABLE grpc_python_plugin)
|
|
||||||
else ()
|
|
||||||
set(_GRPC_CPP_PLUGIN_EXECUTABLE $<TARGET_FILE:gRPC::grpc_cpp_plugin>)
|
|
||||||
set(_GRPC_PYTHON_PLUGIN_EXECUTABLE $<TARGET_FILE:gRPC::grpc_python_plugin>)
|
|
||||||
endif ()
|
|
||||||
|
|
||||||
# Proto file
|
|
||||||
get_filename_component(hw_proto "ms_service.proto" ABSOLUTE)
|
|
||||||
get_filename_component(hw_proto_path ${hw_proto} PATH)
|
|
||||||
# Generated sources
|
|
||||||
set(hw_proto_srcs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.pb.cc")
|
|
||||||
set(hw_proto_hdrs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.pb.h")
|
|
||||||
set(hw_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.grpc.pb.cc")
|
|
||||||
set(hw_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/ms_service.grpc.pb.h")
|
|
||||||
set(hw_py_pb2 "${CMAKE_CURRENT_BINARY_DIR}/ms_service_pb2.py")
|
|
||||||
set(hw_py_pb2_grpc "${CMAKE_CURRENT_BINARY_DIR}/ms_service_pb2_grpc.py")
|
|
||||||
add_custom_command(
|
|
||||||
OUTPUT "${hw_proto_srcs}" "${hw_proto_hdrs}" "${hw_grpc_srcs}" "${hw_grpc_hdrs}" "${hw_py_pb2}" "${hw_py_pb2_grpc}"
|
|
||||||
COMMAND ${_PROTOBUF_PROTOC}
|
|
||||||
ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}"
|
|
||||||
--cpp_out "${CMAKE_CURRENT_BINARY_DIR}"
|
|
||||||
-I "${hw_proto_path}"
|
|
||||||
--plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}"
|
|
||||||
"${hw_proto}"
|
|
||||||
COMMAND ${_PROTOBUF_PROTOC}
|
|
||||||
ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}"
|
|
||||||
--python_out "${CMAKE_CURRENT_BINARY_DIR}"
|
|
||||||
-I "${hw_proto_path}"
|
|
||||||
--plugin=protoc-gen-grpc="${_GRPC_PYTHON_PLUGIN_EXECUTABLE}"
|
|
||||||
"${hw_proto}"
|
|
||||||
DEPENDS "${hw_proto}")
|
|
||||||
|
|
||||||
list(APPEND SERVING_SRC_TEST ${hw_proto_srcs} ${hw_grpc_srcs})
|
|
||||||
|
|
||||||
file(GLOB_RECURSE ACL_SESSION_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|
||||||
"../../../../serving/acl/*.cc"
|
|
||||||
"../../../../serving/core/*.cc")
|
|
||||||
list(APPEND SERVING_SRC_TEST ${ACL_SESSION_SRC_LIST})
|
|
||||||
|
|
||||||
# utest files
|
|
||||||
file(GLOB_RECURSE ACL_UTEST_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
|
||||||
list(APPEND SERVING_SRC_TEST ${ACL_UTEST_SRC_LIST})
|
|
||||||
|
|
||||||
include_directories(${CMAKE_SOURCE_DIR}/serving/core)
|
|
||||||
include_directories(${CMAKE_SOURCE_DIR}/serving/acl)
|
|
||||||
include_directories(${CMAKE_SOURCE_DIR}/serving)
|
|
||||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
|
|
||||||
include_directories(${CMAKE_CURRENT_BINARY_DIR})
|
|
||||||
include_directories(${CMAKE_CURRENT_BINARY_DIR}/../)
|
|
||||||
add_library(_ut_serving_obj OBJECT ${SERVING_SRC_TEST})
|
|
||||||
add_compile_definitions(ENABLE_DVPP_INTERFACE)
|
|
|
@ -1,163 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "acl_session_test_common.h"
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace serving {
|
|
||||||
|
|
||||||
class AclSessionAddTest : public AclSessionTest {
|
|
||||||
public:
|
|
||||||
AclSessionAddTest() = default;
|
|
||||||
void SetUp() override {
|
|
||||||
AclSessionTest::SetUp();
|
|
||||||
aclmdlDesc model_desc;
|
|
||||||
model_desc.inputs.push_back(
|
|
||||||
AclTensorDesc{.dims = {2, 24, 24, 3}, .data_type = ACL_FLOAT, .size = 2 * 24 * 24 * 3 * sizeof(float)});
|
|
||||||
|
|
||||||
model_desc.inputs.push_back(
|
|
||||||
AclTensorDesc{.dims = {2, 24, 24, 3}, .data_type = ACL_FLOAT, .size = 2 * 24 * 24 * 3 * sizeof(float)});
|
|
||||||
|
|
||||||
model_desc.outputs.push_back(
|
|
||||||
AclTensorDesc{.dims = {2, 24, 24, 3}, .data_type = ACL_FLOAT, .size = 2 * 24 * 24 * 3 * sizeof(float)});
|
|
||||||
|
|
||||||
mock_model_desc_ = MockModelDesc(model_desc);
|
|
||||||
g_acl_model_desc = &mock_model_desc_;
|
|
||||||
g_acl_model = &add_mock_model_;
|
|
||||||
}
|
|
||||||
void CreateDefaultRequest(PredictRequest &request) {
|
|
||||||
auto input0 = request.add_data();
|
|
||||||
CreateTensor(*input0, {2, 24, 24, 3}, ::ms_serving::DataType::MS_FLOAT32);
|
|
||||||
auto input1 = request.add_data();
|
|
||||||
CreateTensor(*input1, {2, 24, 24, 3}, ::ms_serving::DataType::MS_FLOAT32);
|
|
||||||
|
|
||||||
auto input0_data = reinterpret_cast<float *>(input0->mutable_data()->data());
|
|
||||||
auto input1_data = reinterpret_cast<float *>(input1->mutable_data()->data());
|
|
||||||
for (int i = 0; i < 2 * 24 * 24 * 3; i++) {
|
|
||||||
input0_data[i] = i % 1024;
|
|
||||||
input1_data[i] = i % 1024 + 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void CheckDefaultReply(const PredictReply &reply) {
|
|
||||||
EXPECT_TRUE(reply.result().size() == 1);
|
|
||||||
if (reply.result().size() == 1) {
|
|
||||||
CheckTensorItem(reply.result(0), {2, 24, 24, 3}, ::ms_serving::DataType::MS_FLOAT32);
|
|
||||||
auto &output = reply.result(0).data();
|
|
||||||
EXPECT_EQ(output.size(), 2 * 24 * 24 * 3 * sizeof(float));
|
|
||||||
if (output.size() == 2 * 24 * 24 * 3 * sizeof(float)) {
|
|
||||||
auto output_data = reinterpret_cast<const float *>(output.data());
|
|
||||||
for (int i = 0; i < 2 * 24 * 24 * 3; i++) {
|
|
||||||
EXPECT_EQ(output_data[i], (i % 1024) + (i % 1024 + 1));
|
|
||||||
if (output_data[i] != (i % 1024) + (i % 1024 + 1)) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
MockModelDesc mock_model_desc_;
|
|
||||||
AddMockAclModel add_mock_model_;
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(AclSessionAddTest, TestAclSession_OneTime_Success) {
|
|
||||||
inference::AclSession acl_session;
|
|
||||||
uint32_t device_id = 1;
|
|
||||||
EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS);
|
|
||||||
uint32_t model_id = 0;
|
|
||||||
EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS);
|
|
||||||
// create inputs
|
|
||||||
PredictRequest request;
|
|
||||||
CreateDefaultRequest(request);
|
|
||||||
|
|
||||||
PredictReply reply;
|
|
||||||
ServingRequest serving_request(request);
|
|
||||||
ServingReply serving_reply(reply);
|
|
||||||
EXPECT_TRUE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS);
|
|
||||||
CheckDefaultReply(reply);
|
|
||||||
|
|
||||||
EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS);
|
|
||||||
EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS);
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(AclSessionAddTest, TestAclSession_MutilTimes_Success) {
|
|
||||||
inference::AclSession acl_session;
|
|
||||||
uint32_t device_id = 1;
|
|
||||||
EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS);
|
|
||||||
uint32_t model_id = 0;
|
|
||||||
EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS);
|
|
||||||
for (int i = 0; i < 10; i++) {
|
|
||||||
// create inputs
|
|
||||||
PredictRequest request;
|
|
||||||
CreateDefaultRequest(request);
|
|
||||||
|
|
||||||
PredictReply reply;
|
|
||||||
ServingRequest serving_request(request);
|
|
||||||
ServingReply serving_reply(reply);
|
|
||||||
EXPECT_TRUE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS);
|
|
||||||
CheckDefaultReply(reply);
|
|
||||||
}
|
|
||||||
EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS);
|
|
||||||
EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS);
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(AclSessionAddTest, TestAclSession_DeviceRunMode_OneTime_Success) {
|
|
||||||
SetDeviceRunMode();
|
|
||||||
inference::AclSession acl_session;
|
|
||||||
uint32_t device_id = 1;
|
|
||||||
EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS);
|
|
||||||
uint32_t model_id = 0;
|
|
||||||
EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS);
|
|
||||||
// create inputs
|
|
||||||
PredictRequest request;
|
|
||||||
CreateDefaultRequest(request);
|
|
||||||
|
|
||||||
PredictReply reply;
|
|
||||||
ServingRequest serving_request(request);
|
|
||||||
ServingReply serving_reply(reply);
|
|
||||||
EXPECT_TRUE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS);
|
|
||||||
CheckDefaultReply(reply);
|
|
||||||
|
|
||||||
EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS);
|
|
||||||
EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS);
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(AclSessionAddTest, TestAclSession_DeviceRunMode_MutilTimes_Success) {
|
|
||||||
SetDeviceRunMode();
|
|
||||||
inference::AclSession acl_session;
|
|
||||||
uint32_t device_id = 1;
|
|
||||||
EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS);
|
|
||||||
uint32_t model_id = 0;
|
|
||||||
EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS);
|
|
||||||
for (int i = 0; i < 10; i++) {
|
|
||||||
// create inputs
|
|
||||||
PredictRequest request;
|
|
||||||
CreateDefaultRequest(request);
|
|
||||||
|
|
||||||
PredictReply reply;
|
|
||||||
ServingRequest serving_request(request);
|
|
||||||
ServingReply serving_reply(reply);
|
|
||||||
EXPECT_TRUE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS);
|
|
||||||
CheckDefaultReply(reply);
|
|
||||||
}
|
|
||||||
EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS);
|
|
||||||
EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS);
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace serving
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,192 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_ACL_SESSION_TEST_COMMON_H
|
|
||||||
#define MINDSPORE_ACL_SESSION_TEST_COMMON_H
|
|
||||||
|
|
||||||
#include "common/common_test.h"
|
|
||||||
#include "serving/core/server.h"
|
|
||||||
#include "serving/core/session.h"
|
|
||||||
#include "include/inference.h"
|
|
||||||
#include "include/infer_tensor.h"
|
|
||||||
#include "serving/core/serving_tensor.h"
|
|
||||||
#include "serving/acl/acl_session.h"
|
|
||||||
#include "serving/acl/model_process.h"
|
|
||||||
#include "serving/acl/dvpp_process.h"
|
|
||||||
#include "acl_stub.h"
|
|
||||||
|
|
||||||
class MockDeviceRunMode : public AclRunMode {
|
|
||||||
public:
|
|
||||||
aclError aclrtGetRunMode(aclrtRunMode *runMode) override {
|
|
||||||
*runMode = aclrtRunMode::ACL_DEVICE;
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
class AclSessionTest : public testing::Test {
|
|
||||||
public:
|
|
||||||
AclSessionTest() = default;
|
|
||||||
void SetUp() override {
|
|
||||||
g_acl_data_buffer = &g_acl_data_buffer_default;
|
|
||||||
g_acl_env = &g_acl_env_default;
|
|
||||||
g_acl_dataset = &g_acl_dataset_default;
|
|
||||||
g_acl_model = &g_acl_model_default;
|
|
||||||
g_acl_model_desc = &g_acl_model_desc_default;
|
|
||||||
g_acl_device_context_stream = &g_acl_device_context_stream_default;
|
|
||||||
g_acl_memory = &g_acl_memory_default;
|
|
||||||
g_acl_dvpp_pic_desc = &g_acl_dvpp_pic_desc_default;
|
|
||||||
g_acl_dvpp_roi_config = &g_acl_dvpp_roi_config_default;
|
|
||||||
g_acl_dvpp_resize_config = &g_acl_dvpp_resize_config_default;
|
|
||||||
g_acl_dvpp_channel_desc = &g_acl_dvpp_channel_desc_default;
|
|
||||||
g_acl_dvpp_process = &g_acl_dvpp_process_default;
|
|
||||||
g_acl_run_mode = &acl_run_mode_default;
|
|
||||||
g_acl_jpeg_lib = &acl_jpeg_lib_default;
|
|
||||||
}
|
|
||||||
void TearDown() override {
|
|
||||||
EXPECT_TRUE(g_acl_data_buffer->Check());
|
|
||||||
EXPECT_TRUE(g_acl_env->Check());
|
|
||||||
EXPECT_TRUE(g_acl_dataset->Check());
|
|
||||||
EXPECT_TRUE(g_acl_model->Check());
|
|
||||||
EXPECT_TRUE(g_acl_model_desc->Check());
|
|
||||||
EXPECT_TRUE(g_acl_device_context_stream->Check());
|
|
||||||
EXPECT_TRUE(g_acl_memory->Check());
|
|
||||||
EXPECT_TRUE(g_acl_dvpp_pic_desc->Check());
|
|
||||||
EXPECT_TRUE(g_acl_dvpp_roi_config->Check());
|
|
||||||
EXPECT_TRUE(g_acl_dvpp_resize_config->Check());
|
|
||||||
EXPECT_TRUE(g_acl_dvpp_channel_desc->Check());
|
|
||||||
EXPECT_TRUE(g_acl_dvpp_process->Check());
|
|
||||||
EXPECT_TRUE(g_acl_jpeg_lib->Check());
|
|
||||||
}
|
|
||||||
|
|
||||||
AclDataBuffer g_acl_data_buffer_default;
|
|
||||||
AclEnv g_acl_env_default;
|
|
||||||
AclDataSet g_acl_dataset_default;
|
|
||||||
AclModel g_acl_model_default;
|
|
||||||
AclModelDesc g_acl_model_desc_default;
|
|
||||||
AclDeviceContextStream g_acl_device_context_stream_default;
|
|
||||||
AclMemory g_acl_memory_default;
|
|
||||||
AclDvppPicDesc g_acl_dvpp_pic_desc_default;
|
|
||||||
AclDvppRoiConfig g_acl_dvpp_roi_config_default;
|
|
||||||
AclDvppResizeConfig g_acl_dvpp_resize_config_default;
|
|
||||||
AclDvppChannelDesc g_acl_dvpp_channel_desc_default;
|
|
||||||
AclDvppProcess g_acl_dvpp_process_default;
|
|
||||||
AclRunMode acl_run_mode_default;
|
|
||||||
MockDeviceRunMode acl_device_run_mode;
|
|
||||||
AclJpegLib acl_jpeg_lib_default = AclJpegLib(0, 0);
|
|
||||||
|
|
||||||
void SetDeviceRunMode() { g_acl_run_mode = &acl_device_run_mode; }
|
|
||||||
void CreateTensor(ms_serving::Tensor &tensor, const std::vector<int64_t> &shape, ms_serving::DataType data_type,
|
|
||||||
std::size_t data_size = INT64_MAX) {
|
|
||||||
if (data_size == INT64_MAX) {
|
|
||||||
data_size = GetDataTypeSize(data_type);
|
|
||||||
for (auto item : shape) {
|
|
||||||
data_size *= item;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tensor.set_data(std::string(data_size, 0));
|
|
||||||
tensor.set_tensor_type(data_type);
|
|
||||||
auto tensor_shape = tensor.mutable_tensor_shape();
|
|
||||||
for (auto item : shape) {
|
|
||||||
tensor_shape->add_dims(item);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t GetDataTypeSize(ms_serving::DataType data_type) {
|
|
||||||
const std::map<ms_serving::DataType, size_t> type_size_map{
|
|
||||||
{ms_serving::DataType::MS_BOOL, sizeof(bool)}, {ms_serving::DataType::MS_INT8, sizeof(int8_t)},
|
|
||||||
{ms_serving::DataType::MS_UINT8, sizeof(uint8_t)}, {ms_serving::DataType::MS_INT16, sizeof(int16_t)},
|
|
||||||
{ms_serving::DataType::MS_UINT16, sizeof(uint16_t)}, {ms_serving::DataType::MS_INT32, sizeof(int32_t)},
|
|
||||||
{ms_serving::DataType::MS_UINT32, sizeof(uint32_t)}, {ms_serving::DataType::MS_INT64, sizeof(int64_t)},
|
|
||||||
{ms_serving::DataType::MS_UINT64, sizeof(uint64_t)}, {ms_serving::DataType::MS_FLOAT16, 2},
|
|
||||||
{ms_serving::DataType::MS_FLOAT32, sizeof(float)}, {ms_serving::DataType::MS_FLOAT64, sizeof(double)},
|
|
||||||
};
|
|
||||||
auto it = type_size_map.find(data_type);
|
|
||||||
if (it == type_size_map.end()) {
|
|
||||||
EXPECT_TRUE(false);
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
return it->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
void CheckTensorItem(const ms_serving::Tensor &tensor, const std::vector<int64_t> &expect_shape,
|
|
||||||
ms_serving::DataType expect_data_type) {
|
|
||||||
std::vector<int64_t> tensor_shape;
|
|
||||||
for (auto item : tensor.tensor_shape().dims()) {
|
|
||||||
tensor_shape.push_back(item);
|
|
||||||
}
|
|
||||||
EXPECT_EQ(expect_shape, tensor_shape);
|
|
||||||
EXPECT_EQ(expect_data_type, tensor.tensor_type());
|
|
||||||
int64_t elem_cnt = 1;
|
|
||||||
for (auto item : expect_shape) {
|
|
||||||
elem_cnt *= item;
|
|
||||||
}
|
|
||||||
auto data_size = GetDataTypeSize(expect_data_type);
|
|
||||||
EXPECT_EQ(data_size * elem_cnt, tensor.data().size());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
class MockModelDesc : public AclModelDesc {
|
|
||||||
public:
|
|
||||||
MockModelDesc() {}
|
|
||||||
MockModelDesc(const aclmdlDesc &mock_model_desc) : mock_model_desc_(mock_model_desc) {}
|
|
||||||
aclmdlDesc *aclmdlCreateDesc() override {
|
|
||||||
aclmdlDesc *model_desc = AclModelDesc::aclmdlCreateDesc();
|
|
||||||
*model_desc = mock_model_desc_;
|
|
||||||
return model_desc;
|
|
||||||
}
|
|
||||||
aclmdlDesc mock_model_desc_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class AddMockAclModel : public AclModel {
|
|
||||||
public:
|
|
||||||
aclError aclmdlExecute(uint32_t modelId, const aclmdlDataset *input, aclmdlDataset *output) override {
|
|
||||||
if (AclModel::aclmdlExecute(modelId, input, output) != ACL_ERROR_NONE) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
if (input->data_buffers.size() != 2) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
auto &input0 = input->data_buffers[0];
|
|
||||||
auto &input1 = input->data_buffers[1];
|
|
||||||
std::size_t expect_count = input0->size / sizeof(float);
|
|
||||||
if (input0->size != expect_count * sizeof(float) || input1->size != expect_count * sizeof(float)) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (output->data_buffers.size() != 1) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
auto &output0 = output->data_buffers[0];
|
|
||||||
if (output0->size != expect_count * sizeof(float)) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto input0_data = reinterpret_cast<const float *>(input0->data);
|
|
||||||
auto input1_data = reinterpret_cast<const float *>(input1->data);
|
|
||||||
auto output0_data = reinterpret_cast<float *>(output0->data);
|
|
||||||
for (size_t i = 0; i < expect_count; i++) {
|
|
||||||
output0_data[i] = input0_data[i] + input1_data[i];
|
|
||||||
}
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError aclmdlExecuteAsync(uint32_t modelId, const aclmdlDataset *input, aclmdlDataset *output,
|
|
||||||
aclrtStream stream) override {
|
|
||||||
return aclmdlExecute(modelId, input, output);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif // MINDSPORE_ACL_SESSION_TEST_COMMON_H
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,342 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "acl_session_test_common.h"
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace serving {
|
|
||||||
|
|
||||||
class MockFailAclDeviceContextStream : public AclDeviceContextStream {
|
|
||||||
public:
|
|
||||||
aclError aclrtSetDevice(int32_t deviceId) override {
|
|
||||||
if (set_device_fail_list_.empty()) {
|
|
||||||
return AclDeviceContextStream::aclrtSetDevice(deviceId);
|
|
||||||
}
|
|
||||||
auto val = set_device_fail_list_.front();
|
|
||||||
set_device_fail_list_.erase(set_device_fail_list_.begin());
|
|
||||||
if (val) {
|
|
||||||
return AclDeviceContextStream::aclrtSetDevice(deviceId);
|
|
||||||
}
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError aclrtResetDevice(int32_t deviceId) override {
|
|
||||||
auto ret = AclDeviceContextStream::aclrtResetDevice(deviceId);
|
|
||||||
if (ret != ACL_ERROR_NONE) {
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
if (reset_device_fail_list_.empty()) {
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
auto val = reset_device_fail_list_.front();
|
|
||||||
reset_device_fail_list_.erase(reset_device_fail_list_.begin());
|
|
||||||
return val ? ACL_ERROR_NONE : 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError aclrtCreateContext(aclrtContext *context, int32_t deviceId) override {
|
|
||||||
if (create_context_fail_list_.empty()) {
|
|
||||||
return AclDeviceContextStream::aclrtCreateContext(context, deviceId);
|
|
||||||
}
|
|
||||||
auto val = create_context_fail_list_.front();
|
|
||||||
create_context_fail_list_.erase(create_context_fail_list_.begin());
|
|
||||||
if (val) {
|
|
||||||
return AclDeviceContextStream::aclrtCreateContext(context, deviceId);
|
|
||||||
}
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError aclrtDestroyContext(aclrtContext context) override {
|
|
||||||
auto ret = AclDeviceContextStream::aclrtDestroyContext(context);
|
|
||||||
if (ret != ACL_ERROR_NONE) {
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
if (destroy_context_fail_list_.empty()) {
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
auto val = destroy_context_fail_list_.front();
|
|
||||||
destroy_context_fail_list_.erase(destroy_context_fail_list_.begin());
|
|
||||||
return val ? ACL_ERROR_NONE : 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError aclrtCreateStream(aclrtStream *stream) override {
|
|
||||||
if (create_stream_fail_list_.empty()) {
|
|
||||||
return AclDeviceContextStream::aclrtCreateStream(stream);
|
|
||||||
}
|
|
||||||
auto val = create_stream_fail_list_.front();
|
|
||||||
create_stream_fail_list_.erase(create_stream_fail_list_.begin());
|
|
||||||
if (val) {
|
|
||||||
return AclDeviceContextStream::aclrtCreateStream(stream);
|
|
||||||
}
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError aclrtDestroyStream(aclrtStream stream) override {
|
|
||||||
auto ret = AclDeviceContextStream::aclrtDestroyStream(stream);
|
|
||||||
if (ret != ACL_ERROR_NONE) {
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
if (destroy_stream_fail_list_.empty()) {
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
auto val = destroy_stream_fail_list_.front();
|
|
||||||
destroy_stream_fail_list_.erase(destroy_stream_fail_list_.begin());
|
|
||||||
return val ? ACL_ERROR_NONE : 1;
|
|
||||||
}
|
|
||||||
std::vector<bool> set_device_fail_list_;
|
|
||||||
std::vector<bool> reset_device_fail_list_;
|
|
||||||
std::vector<bool> create_context_fail_list_;
|
|
||||||
std::vector<bool> destroy_context_fail_list_;
|
|
||||||
std::vector<bool> create_stream_fail_list_;
|
|
||||||
std::vector<bool> destroy_stream_fail_list_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class MockFailAclMemory : public AclMemory {
|
|
||||||
public:
|
|
||||||
aclError aclrtMalloc(void **devPtr, size_t size, aclrtMemMallocPolicy policy) override {
|
|
||||||
if (device_mem_fail_list_.empty()) {
|
|
||||||
return AclMemory::aclrtMalloc(devPtr, size, policy);
|
|
||||||
}
|
|
||||||
auto val = device_mem_fail_list_.front();
|
|
||||||
device_mem_fail_list_.erase(device_mem_fail_list_.begin());
|
|
||||||
if (val) {
|
|
||||||
return AclMemory::aclrtMalloc(devPtr, size, policy);
|
|
||||||
}
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
aclError aclrtMallocHost(void **hostPtr, size_t size) override {
|
|
||||||
if (host_mem_fail_list_.empty()) {
|
|
||||||
return AclMemory::aclrtMallocHost(hostPtr, size);
|
|
||||||
}
|
|
||||||
auto val = host_mem_fail_list_.front();
|
|
||||||
host_mem_fail_list_.erase(host_mem_fail_list_.begin());
|
|
||||||
if (val) {
|
|
||||||
return AclMemory::aclrtMallocHost(hostPtr, size);
|
|
||||||
}
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
aclError acldvppMalloc(void **devPtr, size_t size) override {
|
|
||||||
if (dvpp_mem_fail_list_.empty()) {
|
|
||||||
return AclMemory::acldvppMalloc(devPtr, size);
|
|
||||||
}
|
|
||||||
auto val = dvpp_mem_fail_list_.front();
|
|
||||||
dvpp_mem_fail_list_.erase(dvpp_mem_fail_list_.begin());
|
|
||||||
if (val) {
|
|
||||||
return AclMemory::acldvppMalloc(devPtr, size);
|
|
||||||
}
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<bool> device_mem_fail_list_;
|
|
||||||
std::vector<bool> host_mem_fail_list_;
|
|
||||||
std::vector<bool> dvpp_mem_fail_list_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class AclSessionModelLoadTest : public AclSessionTest {
|
|
||||||
public:
|
|
||||||
AclSessionModelLoadTest() = default;
|
|
||||||
void SetUp() override {
|
|
||||||
AclSessionTest::SetUp();
|
|
||||||
aclmdlDesc model_desc;
|
|
||||||
model_desc.inputs.push_back(
|
|
||||||
AclTensorDesc{.dims = {2, 24, 24, 3}, .data_type = ACL_FLOAT, .size = 2 * 24 * 24 * 3 * sizeof(float)});
|
|
||||||
|
|
||||||
model_desc.inputs.push_back(
|
|
||||||
AclTensorDesc{.dims = {2, 24, 24, 3}, .data_type = ACL_FLOAT, .size = 2 * 24 * 24 * 3 * sizeof(float)});
|
|
||||||
|
|
||||||
model_desc.outputs.push_back(
|
|
||||||
AclTensorDesc{.dims = {2, 24, 24, 3}, .data_type = ACL_FLOAT, .size = 2 * 24 * 24 * 3 * sizeof(float)});
|
|
||||||
|
|
||||||
model_desc.outputs.push_back(
|
|
||||||
AclTensorDesc{.dims = {2, 24, 24, 3}, .data_type = ACL_FLOAT, .size = 2 * 24 * 24 * 3 * sizeof(float)});
|
|
||||||
|
|
||||||
mock_model_desc_ = MockModelDesc(model_desc);
|
|
||||||
g_acl_model_desc = &mock_model_desc_;
|
|
||||||
g_acl_device_context_stream = &fail_acl_device_context_stream_;
|
|
||||||
g_acl_memory = &fail_acl_memory_;
|
|
||||||
}
|
|
||||||
void CreateDefaultRequest(PredictRequest &request) {
|
|
||||||
auto input0 = request.add_data();
|
|
||||||
CreateTensor(*input0, {2, 24, 24, 3}, ::ms_serving::DataType::MS_FLOAT32);
|
|
||||||
auto input1 = request.add_data();
|
|
||||||
CreateTensor(*input1, {2, 24, 24, 3}, ::ms_serving::DataType::MS_FLOAT32);
|
|
||||||
}
|
|
||||||
|
|
||||||
void CheckDefaultReply(const PredictReply &reply) {
|
|
||||||
EXPECT_TRUE(reply.result().size() == 2);
|
|
||||||
if (reply.result().size() == 2) {
|
|
||||||
CheckTensorItem(reply.result(0), {2, 24, 24, 3}, ::ms_serving::DataType::MS_FLOAT32);
|
|
||||||
CheckTensorItem(reply.result(1), {2, 24, 24, 3}, ::ms_serving::DataType::MS_FLOAT32);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
MockModelDesc mock_model_desc_;
|
|
||||||
/* Test Resource will be release on something wrong happens*/
|
|
||||||
MockFailAclDeviceContextStream fail_acl_device_context_stream_;
|
|
||||||
MockFailAclMemory fail_acl_memory_;
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(AclSessionModelLoadTest, TestAclSession_OneTime_Success) {
|
|
||||||
inference::AclSession acl_session;
|
|
||||||
uint32_t device_id = 1;
|
|
||||||
EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS);
|
|
||||||
uint32_t model_id = 0;
|
|
||||||
EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS);
|
|
||||||
// create inputs
|
|
||||||
PredictRequest request;
|
|
||||||
CreateDefaultRequest(request);
|
|
||||||
|
|
||||||
PredictReply reply;
|
|
||||||
ServingRequest serving_request(request);
|
|
||||||
ServingReply serving_reply(reply);
|
|
||||||
EXPECT_TRUE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS);
|
|
||||||
CheckDefaultReply(reply);
|
|
||||||
|
|
||||||
EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS);
|
|
||||||
EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS);
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(AclSessionModelLoadTest, TestAclSession_SetDeviceFail) {
|
|
||||||
inference::AclSession acl_session;
|
|
||||||
uint32_t device_id = 1;
|
|
||||||
fail_acl_device_context_stream_.set_device_fail_list_.push_back(false);
|
|
||||||
EXPECT_FALSE(acl_session.InitEnv("Ascend", device_id) == SUCCESS);
|
|
||||||
EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS);
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(AclSessionModelLoadTest, TestAclSession_CreateContextFail) {
|
|
||||||
inference::AclSession acl_session;
|
|
||||||
uint32_t device_id = 1;
|
|
||||||
fail_acl_device_context_stream_.create_context_fail_list_.push_back(false);
|
|
||||||
EXPECT_FALSE(acl_session.InitEnv("Ascend", device_id) == SUCCESS);
|
|
||||||
EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS);
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(AclSessionModelLoadTest, TestAclSession_CreateStreamFail) {
|
|
||||||
inference::AclSession acl_session;
|
|
||||||
uint32_t device_id = 1;
|
|
||||||
fail_acl_device_context_stream_.create_stream_fail_list_.push_back(false);
|
|
||||||
EXPECT_FALSE(acl_session.InitEnv("Ascend", device_id) == SUCCESS);
|
|
||||||
EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS);
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(AclSessionModelLoadTest, TestAclSession_ResetDeviceFail) {
|
|
||||||
inference::AclSession acl_session;
|
|
||||||
uint32_t device_id = 1;
|
|
||||||
fail_acl_device_context_stream_.reset_device_fail_list_.push_back(false);
|
|
||||||
EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS);
|
|
||||||
acl_session.FinalizeEnv();
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(AclSessionModelLoadTest, TestAclSession_DestroyContextFail) {
|
|
||||||
inference::AclSession acl_session;
|
|
||||||
uint32_t device_id = 1;
|
|
||||||
fail_acl_device_context_stream_.destroy_context_fail_list_.push_back(false);
|
|
||||||
EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS);
|
|
||||||
acl_session.FinalizeEnv();
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(AclSessionModelLoadTest, TestAclSession_DestroyStreamFail) {
|
|
||||||
inference::AclSession acl_session;
|
|
||||||
uint32_t device_id = 1;
|
|
||||||
fail_acl_device_context_stream_.destroy_stream_fail_list_.push_back(false);
|
|
||||||
EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS);
|
|
||||||
acl_session.FinalizeEnv();
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(AclSessionModelLoadTest, TestAclSession_MallocFail0_Success) {
|
|
||||||
inference::AclSession acl_session;
|
|
||||||
uint32_t device_id = 1;
|
|
||||||
EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS);
|
|
||||||
uint32_t model_id = 0;
|
|
||||||
fail_acl_memory_.device_mem_fail_list_.push_back(false); // input0 buffer
|
|
||||||
EXPECT_FALSE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS);
|
|
||||||
|
|
||||||
EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS);
|
|
||||||
EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS);
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(AclSessionModelLoadTest, TestAclSession_MallocFail1_Success) {
|
|
||||||
inference::AclSession acl_session;
|
|
||||||
uint32_t device_id = 1;
|
|
||||||
EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS);
|
|
||||||
uint32_t model_id = 0;
|
|
||||||
fail_acl_memory_.device_mem_fail_list_.push_back(true); // input0 buffer
|
|
||||||
fail_acl_memory_.device_mem_fail_list_.push_back(false); // input1 buffer
|
|
||||||
EXPECT_FALSE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS);
|
|
||||||
|
|
||||||
EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS);
|
|
||||||
EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS);
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(AclSessionModelLoadTest, TestAclSession_MallocFail2_Success) {
|
|
||||||
inference::AclSession acl_session;
|
|
||||||
uint32_t device_id = 1;
|
|
||||||
EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS);
|
|
||||||
uint32_t model_id = 0;
|
|
||||||
fail_acl_memory_.device_mem_fail_list_.push_back(true); // input0 buffer
|
|
||||||
fail_acl_memory_.device_mem_fail_list_.push_back(true); // input1 buffer
|
|
||||||
fail_acl_memory_.device_mem_fail_list_.push_back(false); // output0 buffer
|
|
||||||
EXPECT_FALSE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS);
|
|
||||||
|
|
||||||
EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS);
|
|
||||||
EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS);
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(AclSessionModelLoadTest, TestAclSession_MallocFail3_Success) {
|
|
||||||
inference::AclSession acl_session;
|
|
||||||
uint32_t device_id = 1;
|
|
||||||
EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS);
|
|
||||||
uint32_t model_id = 0;
|
|
||||||
fail_acl_memory_.device_mem_fail_list_.push_back(true); // input0 buffer
|
|
||||||
fail_acl_memory_.device_mem_fail_list_.push_back(true); // input1 buffer
|
|
||||||
fail_acl_memory_.device_mem_fail_list_.push_back(true); // output0 buffer
|
|
||||||
fail_acl_memory_.device_mem_fail_list_.push_back(false); // output1 buffer
|
|
||||||
EXPECT_FALSE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS);
|
|
||||||
|
|
||||||
EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS);
|
|
||||||
EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS);
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(AclSessionModelLoadTest, TestAclSession_RunOnDevice_MallocFail0_Success) {
|
|
||||||
SetDeviceRunMode();
|
|
||||||
inference::AclSession acl_session;
|
|
||||||
uint32_t device_id = 1;
|
|
||||||
EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS);
|
|
||||||
uint32_t model_id = 0;
|
|
||||||
fail_acl_memory_.host_mem_fail_list_.push_back(false); // output0 buffer
|
|
||||||
EXPECT_FALSE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS);
|
|
||||||
|
|
||||||
EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS);
|
|
||||||
EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS);
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(AclSessionModelLoadTest, TestAclSession_RunOnDevice_MallocFail1_Success) {
|
|
||||||
SetDeviceRunMode();
|
|
||||||
inference::AclSession acl_session;
|
|
||||||
uint32_t device_id = 1;
|
|
||||||
EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS);
|
|
||||||
uint32_t model_id = 0;
|
|
||||||
fail_acl_memory_.host_mem_fail_list_.push_back(true); // output0 buffer
|
|
||||||
fail_acl_memory_.host_mem_fail_list_.push_back(false); // output1 buffer
|
|
||||||
EXPECT_FALSE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS);
|
|
||||||
|
|
||||||
EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS);
|
|
||||||
EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS);
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace serving
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,138 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "acl_session_test_common.h"
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace serving {
|
|
||||||
|
|
||||||
class AclSessionOneInputOneOutputTest : public AclSessionTest {
|
|
||||||
public:
|
|
||||||
AclSessionOneInputOneOutputTest() = default;
|
|
||||||
void SetUp() override {
|
|
||||||
AclSessionTest::SetUp();
|
|
||||||
aclmdlDesc model_desc;
|
|
||||||
model_desc.inputs.push_back(
|
|
||||||
AclTensorDesc{.dims = {2, 24, 24, 3}, .data_type = ACL_FLOAT, .size = 2 * 24 * 24 * 3 * sizeof(float)});
|
|
||||||
model_desc.outputs.push_back(
|
|
||||||
AclTensorDesc{.dims = {2, 8, 8, 3}, .data_type = ACL_FLOAT, .size = 2 * 8 * 8 * 3 * sizeof(float)});
|
|
||||||
mock_model_desc_ = MockModelDesc(model_desc);
|
|
||||||
g_acl_model_desc = &mock_model_desc_;
|
|
||||||
}
|
|
||||||
void CreateDefaultRequest(PredictRequest &request) {
|
|
||||||
auto input0 = request.add_data();
|
|
||||||
CreateTensor(*input0, {2, 24, 24, 3}, ::ms_serving::DataType::MS_FLOAT32);
|
|
||||||
}
|
|
||||||
|
|
||||||
void CreateInvalidDataSizeRequest(PredictRequest &request) {
|
|
||||||
auto input0 = request.add_data();
|
|
||||||
// data size invalid, not match model input required
|
|
||||||
CreateTensor(*input0, {2, 24, 24, 2}, ::ms_serving::DataType::MS_FLOAT32);
|
|
||||||
}
|
|
||||||
|
|
||||||
void CheckDefaultReply(const PredictReply &reply) {
|
|
||||||
EXPECT_TRUE(reply.result().size() == 1);
|
|
||||||
if (reply.result().size() == 1) {
|
|
||||||
CheckTensorItem(reply.result(0), {2, 8, 8, 3}, ::ms_serving::DataType::MS_FLOAT32);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
MockModelDesc mock_model_desc_;
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(AclSessionOneInputOneOutputTest, TestAclSession_OneTime_Success) {
|
|
||||||
inference::AclSession acl_session;
|
|
||||||
uint32_t device_id = 1;
|
|
||||||
EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS);
|
|
||||||
uint32_t model_id = 0;
|
|
||||||
EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS);
|
|
||||||
// create inputs
|
|
||||||
PredictRequest request;
|
|
||||||
CreateDefaultRequest(request);
|
|
||||||
|
|
||||||
PredictReply reply;
|
|
||||||
ServingRequest serving_request(request);
|
|
||||||
ServingReply serving_reply(reply);
|
|
||||||
EXPECT_TRUE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS);
|
|
||||||
CheckDefaultReply(reply);
|
|
||||||
|
|
||||||
EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS);
|
|
||||||
EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS);
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(AclSessionOneInputOneOutputTest, TestAclSession_MutilTimes_Success) {
|
|
||||||
inference::AclSession acl_session;
|
|
||||||
uint32_t device_id = 1;
|
|
||||||
EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS);
|
|
||||||
uint32_t model_id = 0;
|
|
||||||
EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS);
|
|
||||||
for (int i = 0; i < 10; i++) {
|
|
||||||
// create inputs
|
|
||||||
PredictRequest request;
|
|
||||||
CreateDefaultRequest(request);
|
|
||||||
|
|
||||||
PredictReply reply;
|
|
||||||
ServingRequest serving_request(request);
|
|
||||||
ServingReply serving_reply(reply);
|
|
||||||
EXPECT_TRUE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS);
|
|
||||||
CheckDefaultReply(reply);
|
|
||||||
}
|
|
||||||
EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS);
|
|
||||||
EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS);
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(AclSessionOneInputOneOutputTest, TestAclSession_InvalidDataSize_Fail) {
|
|
||||||
inference::AclSession acl_session;
|
|
||||||
uint32_t device_id = 1;
|
|
||||||
EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS);
|
|
||||||
uint32_t model_id = 0;
|
|
||||||
EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS);
|
|
||||||
// create inputs
|
|
||||||
PredictRequest request;
|
|
||||||
CreateInvalidDataSizeRequest(request);
|
|
||||||
|
|
||||||
PredictReply reply;
|
|
||||||
ServingRequest serving_request(request);
|
|
||||||
ServingReply serving_reply(reply);
|
|
||||||
EXPECT_FALSE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS);
|
|
||||||
EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS);
|
|
||||||
EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS);
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(AclSessionOneInputOneOutputTest, TestAclSession_InvalidDataSize_MultiTimes_Fail) {
|
|
||||||
inference::AclSession acl_session;
|
|
||||||
uint32_t device_id = 1;
|
|
||||||
EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS);
|
|
||||||
uint32_t model_id = 0;
|
|
||||||
EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS);
|
|
||||||
for (int i = 0; i < 10; i++) {
|
|
||||||
// create inputs
|
|
||||||
PredictRequest request;
|
|
||||||
CreateInvalidDataSizeRequest(request);
|
|
||||||
|
|
||||||
PredictReply reply;
|
|
||||||
ServingRequest serving_request(request);
|
|
||||||
ServingReply serving_reply(reply);
|
|
||||||
EXPECT_FALSE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS);
|
|
||||||
}
|
|
||||||
EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS);
|
|
||||||
EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS);
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace serving
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,226 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "acl_session_test_common.h"
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace serving {
|
|
||||||
|
|
||||||
class AclSessionTwoInputTwoOutputTest : public AclSessionTest {
|
|
||||||
public:
|
|
||||||
AclSessionTwoInputTwoOutputTest() = default;
|
|
||||||
void SetUp() override {
|
|
||||||
AclSessionTest::SetUp();
|
|
||||||
aclmdlDesc model_desc;
|
|
||||||
model_desc.inputs.push_back(
|
|
||||||
AclTensorDesc{.dims = {2, 24, 24, 3}, .data_type = ACL_FLOAT, .size = 2 * 24 * 24 * 3 * sizeof(float)});
|
|
||||||
|
|
||||||
model_desc.inputs.push_back(
|
|
||||||
AclTensorDesc{.dims = {2, 32}, .data_type = ACL_INT32, .size = 2 * 32 * sizeof(int32_t)});
|
|
||||||
|
|
||||||
model_desc.outputs.push_back(
|
|
||||||
AclTensorDesc{.dims = {2, 8, 8, 3}, .data_type = ACL_FLOAT, .size = 2 * 8 * 8 * 3 * sizeof(float)});
|
|
||||||
|
|
||||||
model_desc.outputs.push_back(
|
|
||||||
AclTensorDesc{.dims = {2, 1024}, .data_type = ACL_BOOL, .size = 2 * 1024 * sizeof(bool)});
|
|
||||||
|
|
||||||
mock_model_desc_ = MockModelDesc(model_desc);
|
|
||||||
g_acl_model_desc = &mock_model_desc_;
|
|
||||||
}
|
|
||||||
void CreateDefaultRequest(PredictRequest &request) {
|
|
||||||
auto input0 = request.add_data();
|
|
||||||
CreateTensor(*input0, {2, 24, 24, 3}, ::ms_serving::DataType::MS_FLOAT32);
|
|
||||||
auto input1 = request.add_data();
|
|
||||||
CreateTensor(*input1, {2, 32}, ::ms_serving::DataType::MS_INT32);
|
|
||||||
}
|
|
||||||
|
|
||||||
void CreateInvalidDataSizeRequest0(PredictRequest &request) {
|
|
||||||
auto input0 = request.add_data();
|
|
||||||
// data size invalid, not match model input required
|
|
||||||
CreateTensor(*input0, {2, 24, 24, 2}, ::ms_serving::DataType::MS_FLOAT32);
|
|
||||||
|
|
||||||
auto input1 = request.add_data();
|
|
||||||
CreateTensor(*input1, {2, 32}, ::ms_serving::DataType::MS_INT32);
|
|
||||||
}
|
|
||||||
|
|
||||||
void CreateInvalidDataSizeRequest1(PredictRequest &request) {
|
|
||||||
auto input0 = request.add_data();
|
|
||||||
CreateTensor(*input0, {2, 24, 24, 3}, ::ms_serving::DataType::MS_FLOAT32);
|
|
||||||
auto input1 = request.add_data();
|
|
||||||
// data size invalid, not match model input required
|
|
||||||
CreateTensor(*input1, {2, 16}, ::ms_serving::DataType::MS_INT32);
|
|
||||||
}
|
|
||||||
|
|
||||||
void CreateInvalidDataSizeRequestOneInput0(PredictRequest &request) {
|
|
||||||
// only has one input for input0
|
|
||||||
auto input0 = request.add_data();
|
|
||||||
CreateTensor(*input0, {2, 24, 24, 3}, ::ms_serving::DataType::MS_FLOAT32);
|
|
||||||
}
|
|
||||||
|
|
||||||
void CreateInvalidDataSizeRequestOneInput1(PredictRequest &request) {
|
|
||||||
// only has one input for input1
|
|
||||||
auto input0 = request.add_data();
|
|
||||||
CreateTensor(*input0, {2, 32}, ::ms_serving::DataType::MS_INT32);
|
|
||||||
}
|
|
||||||
|
|
||||||
void CheckDefaultReply(const PredictReply &reply) {
|
|
||||||
EXPECT_TRUE(reply.result().size() == 2);
|
|
||||||
if (reply.result().size() == 2) {
|
|
||||||
CheckTensorItem(reply.result(0), {2, 8, 8, 3}, ::ms_serving::DataType::MS_FLOAT32);
|
|
||||||
CheckTensorItem(reply.result(1), {2, 1024}, ::ms_serving::DataType::MS_BOOL);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
MockModelDesc mock_model_desc_;
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(AclSessionTwoInputTwoOutputTest, TestAclSession_OneTime_Success) {
|
|
||||||
inference::AclSession acl_session;
|
|
||||||
uint32_t device_id = 1;
|
|
||||||
EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS);
|
|
||||||
uint32_t model_id = 0;
|
|
||||||
EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS);
|
|
||||||
// create inputs
|
|
||||||
PredictRequest request;
|
|
||||||
CreateDefaultRequest(request);
|
|
||||||
|
|
||||||
PredictReply reply;
|
|
||||||
ServingRequest serving_request(request);
|
|
||||||
ServingReply serving_reply(reply);
|
|
||||||
EXPECT_TRUE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS);
|
|
||||||
CheckDefaultReply(reply);
|
|
||||||
|
|
||||||
EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS);
|
|
||||||
EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS);
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(AclSessionTwoInputTwoOutputTest, TestAclSession_MutilTimes_Success) {
|
|
||||||
inference::AclSession acl_session;
|
|
||||||
uint32_t device_id = 1;
|
|
||||||
EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS);
|
|
||||||
uint32_t model_id = 0;
|
|
||||||
EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS);
|
|
||||||
for (int i = 0; i < 10; i++) {
|
|
||||||
// create inputs
|
|
||||||
PredictRequest request;
|
|
||||||
CreateDefaultRequest(request);
|
|
||||||
|
|
||||||
PredictReply reply;
|
|
||||||
ServingRequest serving_request(request);
|
|
||||||
ServingReply serving_reply(reply);
|
|
||||||
EXPECT_TRUE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS);
|
|
||||||
CheckDefaultReply(reply);
|
|
||||||
}
|
|
||||||
EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS);
|
|
||||||
EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS);
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(AclSessionTwoInputTwoOutputTest, TestAclSession_Input0_InvalidDataSize_Fail) {
|
|
||||||
inference::AclSession acl_session;
|
|
||||||
uint32_t device_id = 1;
|
|
||||||
EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS);
|
|
||||||
uint32_t model_id = 0;
|
|
||||||
EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS);
|
|
||||||
// create inputs
|
|
||||||
PredictRequest request;
|
|
||||||
CreateInvalidDataSizeRequest0(request);
|
|
||||||
|
|
||||||
PredictReply reply;
|
|
||||||
ServingRequest serving_request(request);
|
|
||||||
ServingReply serving_reply(reply);
|
|
||||||
EXPECT_FALSE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS);
|
|
||||||
EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS);
|
|
||||||
EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS);
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(AclSessionTwoInputTwoOutputTest, TestAclSession_Input1_InvalidDataSize_Fail) {
|
|
||||||
inference::AclSession acl_session;
|
|
||||||
uint32_t device_id = 1;
|
|
||||||
EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS);
|
|
||||||
uint32_t model_id = 0;
|
|
||||||
EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS);
|
|
||||||
// create inputs
|
|
||||||
PredictRequest request;
|
|
||||||
CreateInvalidDataSizeRequest1(request);
|
|
||||||
|
|
||||||
PredictReply reply;
|
|
||||||
ServingRequest serving_request(request);
|
|
||||||
ServingReply serving_reply(reply);
|
|
||||||
EXPECT_FALSE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS);
|
|
||||||
EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS);
|
|
||||||
EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS);
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(AclSessionTwoInputTwoOutputTest, TestAclSession_OnlyInput0_Fail) {
|
|
||||||
inference::AclSession acl_session;
|
|
||||||
uint32_t device_id = 1;
|
|
||||||
EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS);
|
|
||||||
uint32_t model_id = 0;
|
|
||||||
EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS);
|
|
||||||
// create inputs
|
|
||||||
PredictRequest request;
|
|
||||||
CreateInvalidDataSizeRequestOneInput0(request);
|
|
||||||
|
|
||||||
PredictReply reply;
|
|
||||||
ServingRequest serving_request(request);
|
|
||||||
ServingReply serving_reply(reply);
|
|
||||||
EXPECT_FALSE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS);
|
|
||||||
EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS);
|
|
||||||
EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS);
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(AclSessionTwoInputTwoOutputTest, TestAclSession_OnlyInput1_Fail) {
|
|
||||||
inference::AclSession acl_session;
|
|
||||||
uint32_t device_id = 1;
|
|
||||||
EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS);
|
|
||||||
uint32_t model_id = 0;
|
|
||||||
EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS);
|
|
||||||
// create inputs
|
|
||||||
PredictRequest request;
|
|
||||||
CreateInvalidDataSizeRequestOneInput1(request);
|
|
||||||
|
|
||||||
PredictReply reply;
|
|
||||||
ServingRequest serving_request(request);
|
|
||||||
ServingReply serving_reply(reply);
|
|
||||||
EXPECT_FALSE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS);
|
|
||||||
EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS);
|
|
||||||
EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS);
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(AclSessionTwoInputTwoOutputTest, TestAclSession_InvalidDataSize_MultiTimes_Fail) {
|
|
||||||
inference::AclSession acl_session;
|
|
||||||
uint32_t device_id = 1;
|
|
||||||
EXPECT_TRUE(acl_session.InitEnv("Ascend", device_id) == SUCCESS);
|
|
||||||
uint32_t model_id = 0;
|
|
||||||
EXPECT_TRUE(acl_session.LoadModelFromFile("fake_model_path", model_id) == SUCCESS);
|
|
||||||
for (int i = 0; i < 10; i++) {
|
|
||||||
// create inputs
|
|
||||||
PredictRequest request;
|
|
||||||
CreateInvalidDataSizeRequest0(request);
|
|
||||||
|
|
||||||
PredictReply reply;
|
|
||||||
ServingRequest serving_request(request);
|
|
||||||
ServingReply serving_reply(reply);
|
|
||||||
EXPECT_FALSE(acl_session.ExecuteModel(model_id, serving_request, serving_reply) == SUCCESS);
|
|
||||||
}
|
|
||||||
EXPECT_TRUE(acl_session.UnloadModel(model_id) == SUCCESS);
|
|
||||||
EXPECT_TRUE(acl_session.FinalizeEnv() == SUCCESS);
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace serving
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,323 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "acl_stub.h"
|
|
||||||
#include <iostream>
|
|
||||||
|
|
||||||
AclDataBuffer *g_acl_data_buffer = nullptr;
|
|
||||||
AclEnv *g_acl_env = nullptr;
|
|
||||||
AclDataSet *g_acl_dataset = nullptr;
|
|
||||||
AclModel *g_acl_model = nullptr;
|
|
||||||
AclModelDesc *g_acl_model_desc = nullptr;
|
|
||||||
AclDeviceContextStream *g_acl_device_context_stream = nullptr;
|
|
||||||
AclMemory *g_acl_memory = nullptr;
|
|
||||||
AclDvppPicDesc *g_acl_dvpp_pic_desc = nullptr;
|
|
||||||
AclDvppRoiConfig *g_acl_dvpp_roi_config = nullptr;
|
|
||||||
AclDvppResizeConfig *g_acl_dvpp_resize_config = nullptr;
|
|
||||||
AclDvppChannelDesc *g_acl_dvpp_channel_desc = nullptr;
|
|
||||||
AclDvppProcess *g_acl_dvpp_process = nullptr;
|
|
||||||
AclRunMode *g_acl_run_mode = nullptr;
|
|
||||||
AclJpegLib *g_acl_jpeg_lib = nullptr;
|
|
||||||
|
|
||||||
aclDataBuffer *aclCreateDataBuffer(void *data, size_t size) {
|
|
||||||
return g_acl_data_buffer->aclCreateDataBuffer(data, size);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError aclDestroyDataBuffer(const aclDataBuffer *dataBuffer) {
|
|
||||||
return g_acl_data_buffer->aclDestroyDataBuffer(dataBuffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
void *aclGetDataBufferAddr(const aclDataBuffer *dataBuffer) {
|
|
||||||
return g_acl_data_buffer->aclGetDataBufferAddr(dataBuffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t aclGetDataBufferSize(const aclDataBuffer *dataBuffer) {
|
|
||||||
return g_acl_data_buffer->aclGetDataBufferSize(dataBuffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t aclDataTypeSize(aclDataType dataType) {
|
|
||||||
std::unordered_map<aclDataType, size_t> dataTypeMap = {
|
|
||||||
{ACL_FLOAT16, 2}, {ACL_FLOAT, 4}, {ACL_DOUBLE, 8}, {ACL_INT8, 1}, {ACL_INT16, 2}, {ACL_INT32, 4},
|
|
||||||
{ACL_INT64, 8}, {ACL_UINT8, 1}, {ACL_UINT16, 2}, {ACL_UINT32, 4}, {ACL_UINT64, 8}, {ACL_BOOL, 1},
|
|
||||||
};
|
|
||||||
auto it = dataTypeMap.find(dataType);
|
|
||||||
if (it == dataTypeMap.end()) {
|
|
||||||
return 0;
|
|
||||||
} else {
|
|
||||||
return it->second;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void aclAppLog(aclLogLevel logLevel, const char *func, const char *file, uint32_t line, const char *fmt, ...) {
|
|
||||||
if (logLevel == ACL_ERROR) {
|
|
||||||
// std::cout << file << ":" << line << "," << func << ": " << fmt << std::endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError aclInit(const char *configPath) { return g_acl_env->aclInit(configPath); }
|
|
||||||
|
|
||||||
aclError aclFinalize() { return g_acl_env->aclFinalize(); }
|
|
||||||
|
|
||||||
// dataset
|
|
||||||
aclmdlDataset *aclmdlCreateDataset() { return g_acl_dataset->aclmdlCreateDataset(); }
|
|
||||||
|
|
||||||
aclError aclmdlDestroyDataset(const aclmdlDataset *dataSet) { return g_acl_dataset->aclmdlDestroyDataset(dataSet); }
|
|
||||||
|
|
||||||
aclError aclmdlAddDatasetBuffer(aclmdlDataset *dataSet, aclDataBuffer *dataBuffer) {
|
|
||||||
return g_acl_dataset->aclmdlAddDatasetBuffer(dataSet, dataBuffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t aclmdlGetDatasetNumBuffers(const aclmdlDataset *dataSet) {
|
|
||||||
return g_acl_dataset->aclmdlGetDatasetNumBuffers(dataSet);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclDataBuffer *aclmdlGetDatasetBuffer(const aclmdlDataset *dataSet, size_t index) {
|
|
||||||
return g_acl_dataset->aclmdlGetDatasetBuffer(dataSet, index);
|
|
||||||
}
|
|
||||||
|
|
||||||
// model
|
|
||||||
aclError aclmdlLoadFromFile(const char *modelPath, uint32_t *modelId) {
|
|
||||||
return g_acl_model->aclmdlLoadFromFile(modelPath, modelId);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError aclmdlLoadFromMem(const void *model, size_t modelSize, uint32_t *modelId) {
|
|
||||||
return g_acl_model->aclmdlLoadFromMem(model, modelSize, modelId);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError aclmdlLoadFromFileWithMem(const char *modelPath, uint32_t *modelId, void *workPtr, size_t workSize,
|
|
||||||
void *weightPtr, size_t weightSize) {
|
|
||||||
return g_acl_model->aclmdlLoadFromFileWithMem(modelPath, modelId, workPtr, workSize, weightPtr, weightSize);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError aclmdlLoadFromMemWithMem(const void *model, size_t modelSize, uint32_t *modelId, void *workPtr,
|
|
||||||
size_t workSize, void *weightPtr, size_t weightSize) {
|
|
||||||
return g_acl_model->aclmdlLoadFromMemWithMem(model, modelSize, modelId, workPtr, workSize, weightPtr, weightSize);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError aclmdlExecute(uint32_t modelId, const aclmdlDataset *input, aclmdlDataset *output) {
|
|
||||||
return g_acl_model->aclmdlExecute(modelId, input, output);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError aclmdlExecuteAsync(uint32_t modelId, const aclmdlDataset *input, aclmdlDataset *output, aclrtStream stream) {
|
|
||||||
return g_acl_model->aclmdlExecuteAsync(modelId, input, output, stream);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError aclmdlUnload(uint32_t modelId) { return g_acl_model->aclmdlUnload(modelId); }
|
|
||||||
|
|
||||||
// model desc
|
|
||||||
aclmdlDesc *aclmdlCreateDesc() { return g_acl_model_desc->aclmdlCreateDesc(); }
|
|
||||||
|
|
||||||
aclError aclmdlDestroyDesc(aclmdlDesc *modelDesc) { return g_acl_model_desc->aclmdlDestroyDesc(modelDesc); }
|
|
||||||
|
|
||||||
aclError aclmdlGetDesc(aclmdlDesc *modelDesc, uint32_t modelId) {
|
|
||||||
return g_acl_model_desc->aclmdlGetDesc(modelDesc, modelId);
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t aclmdlGetNumInputs(aclmdlDesc *modelDesc) { return g_acl_model_desc->aclmdlGetNumInputs(modelDesc); }
|
|
||||||
|
|
||||||
size_t aclmdlGetNumOutputs(aclmdlDesc *modelDesc) { return g_acl_model_desc->aclmdlGetNumOutputs(modelDesc); }
|
|
||||||
|
|
||||||
size_t aclmdlGetInputSizeByIndex(aclmdlDesc *modelDesc, size_t index) {
|
|
||||||
return g_acl_model_desc->aclmdlGetInputSizeByIndex(modelDesc, index);
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t aclmdlGetOutputSizeByIndex(aclmdlDesc *modelDesc, size_t index) {
|
|
||||||
return g_acl_model_desc->aclmdlGetOutputSizeByIndex(modelDesc, index);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError aclmdlGetInputDims(const aclmdlDesc *modelDesc, size_t index, aclmdlIODims *dims) {
|
|
||||||
return g_acl_model_desc->aclmdlGetInputDims(modelDesc, index, dims);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError aclmdlGetOutputDims(const aclmdlDesc *modelDesc, size_t index, aclmdlIODims *dims) {
|
|
||||||
return g_acl_model_desc->aclmdlGetOutputDims(modelDesc, index, dims);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError aclmdlGetCurOutputDims(const aclmdlDesc *modelDesc, size_t index, aclmdlIODims *dims) {
|
|
||||||
return g_acl_model_desc->aclmdlGetCurOutputDims(modelDesc, index, dims);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclFormat aclmdlGetInputFormat(const aclmdlDesc *modelDesc, size_t index) {
|
|
||||||
return g_acl_model_desc->aclmdlGetInputFormat(modelDesc, index);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclFormat aclmdlGetOutputFormat(const aclmdlDesc *modelDesc, size_t index) {
|
|
||||||
return g_acl_model_desc->aclmdlGetOutputFormat(modelDesc, index);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclDataType aclmdlGetInputDataType(const aclmdlDesc *modelDesc, size_t index) {
|
|
||||||
return g_acl_model_desc->aclmdlGetInputDataType(modelDesc, index);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclDataType aclmdlGetOutputDataType(const aclmdlDesc *modelDesc, size_t index) {
|
|
||||||
return g_acl_model_desc->aclmdlGetOutputDataType(modelDesc, index);
|
|
||||||
}
|
|
||||||
|
|
||||||
// device, context, stream
|
|
||||||
|
|
||||||
aclError aclrtCreateContext(aclrtContext *context, int32_t deviceId) {
|
|
||||||
return g_acl_device_context_stream->aclrtCreateContext(context, deviceId);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError aclrtDestroyContext(aclrtContext context) { return g_acl_device_context_stream->aclrtDestroyContext(context); }
|
|
||||||
|
|
||||||
aclError aclrtSetCurrentContext(aclrtContext context) {
|
|
||||||
return g_acl_device_context_stream->aclrtSetCurrentContext(context);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError aclrtSetDevice(int32_t deviceId) { return g_acl_device_context_stream->aclrtSetDevice(deviceId); }
|
|
||||||
|
|
||||||
aclError aclrtResetDevice(int32_t deviceId) { return g_acl_device_context_stream->aclrtResetDevice(deviceId); }
|
|
||||||
|
|
||||||
aclError aclrtGetRunMode(aclrtRunMode *runMode) { return g_acl_run_mode->aclrtGetRunMode(runMode); }
|
|
||||||
|
|
||||||
aclError aclrtCreateStream(aclrtStream *stream) { return g_acl_device_context_stream->aclrtCreateStream(stream); }
|
|
||||||
|
|
||||||
aclError aclrtDestroyStream(aclrtStream stream) { return g_acl_device_context_stream->aclrtDestroyStream(stream); }
|
|
||||||
|
|
||||||
aclError aclrtSynchronizeStream(aclrtStream stream) {
|
|
||||||
return g_acl_device_context_stream->aclrtSynchronizeStream(stream);
|
|
||||||
}
|
|
||||||
|
|
||||||
// memory
|
|
||||||
aclError acldvppMalloc(void **devPtr, size_t size) { return g_acl_memory->acldvppMalloc(devPtr, size); }
|
|
||||||
aclError acldvppFree(void *devPtr) { return g_acl_memory->acldvppFree(devPtr); }
|
|
||||||
|
|
||||||
aclError aclrtMalloc(void **devPtr, size_t size, aclrtMemMallocPolicy policy) {
|
|
||||||
return g_acl_memory->aclrtMalloc(devPtr, size, policy);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError aclrtFree(void *devPtr) { return g_acl_memory->aclrtFree(devPtr); }
|
|
||||||
|
|
||||||
aclError aclrtMallocHost(void **hostPtr, size_t size) { return g_acl_memory->aclrtMallocHost(hostPtr, size); }
|
|
||||||
|
|
||||||
aclError aclrtFreeHost(void *hostPtr) { return g_acl_memory->aclrtFreeHost(hostPtr); }
|
|
||||||
|
|
||||||
aclError aclrtMemcpy(void *dst, size_t destMax, const void *src, size_t count, aclrtMemcpyKind kind) {
|
|
||||||
return g_acl_memory->aclrtMemcpy(dst, destMax, src, count, kind);
|
|
||||||
}
|
|
||||||
|
|
||||||
acldvppPicDesc *acldvppCreatePicDesc() { return g_acl_dvpp_pic_desc->acldvppCreatePicDesc(); }
|
|
||||||
aclError acldvppDestroyPicDesc(acldvppPicDesc *picDesc) { return g_acl_dvpp_pic_desc->acldvppDestroyPicDesc(picDesc); }
|
|
||||||
|
|
||||||
aclError acldvppSetPicDescSize(acldvppPicDesc *picDesc, uint32_t size) {
|
|
||||||
return g_acl_dvpp_pic_desc->acldvppSetPicDescSize(picDesc, size);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError acldvppSetPicDescFormat(acldvppPicDesc *picDesc, acldvppPixelFormat format) {
|
|
||||||
return g_acl_dvpp_pic_desc->acldvppSetPicDescFormat(picDesc, format);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError acldvppSetPicDescWidth(acldvppPicDesc *picDesc, uint32_t width) {
|
|
||||||
return g_acl_dvpp_pic_desc->acldvppSetPicDescWidth(picDesc, width);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError acldvppSetPicDescHeight(acldvppPicDesc *picDesc, uint32_t height) {
|
|
||||||
return g_acl_dvpp_pic_desc->acldvppSetPicDescHeight(picDesc, height);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError acldvppSetPicDescData(acldvppPicDesc *picDesc, void *dataDev) {
|
|
||||||
return g_acl_dvpp_pic_desc->acldvppSetPicDescData(picDesc, dataDev);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError acldvppSetPicDescWidthStride(acldvppPicDesc *picDesc, uint32_t widthStride) {
|
|
||||||
return g_acl_dvpp_pic_desc->acldvppSetPicDescWidthStride(picDesc, widthStride);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError acldvppSetPicDescHeightStride(acldvppPicDesc *picDesc, uint32_t heightStride) {
|
|
||||||
return g_acl_dvpp_pic_desc->acldvppSetPicDescHeightStride(picDesc, heightStride);
|
|
||||||
}
|
|
||||||
|
|
||||||
acldvppRoiConfig *acldvppCreateRoiConfig(uint32_t left, uint32_t right, uint32_t top, uint32_t bottom) {
|
|
||||||
return g_acl_dvpp_roi_config->acldvppCreateRoiConfig(left, right, top, bottom);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError acldvppDestroyRoiConfig(acldvppRoiConfig *roiConfig) {
|
|
||||||
return g_acl_dvpp_roi_config->acldvppDestroyRoiConfig(roiConfig);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError acldvppSetRoiConfig(acldvppRoiConfig *roiConfig, uint32_t left, uint32_t right, uint32_t top,
|
|
||||||
uint32_t bottom) {
|
|
||||||
return g_acl_dvpp_roi_config->acldvppSetRoiConfig(roiConfig, left, right, top, bottom);
|
|
||||||
}
|
|
||||||
|
|
||||||
acldvppResizeConfig *acldvppCreateResizeConfig() { return g_acl_dvpp_resize_config->acldvppCreateResizeConfig(); }
|
|
||||||
|
|
||||||
aclError acldvppDestroyResizeConfig(acldvppResizeConfig *resizeConfig) {
|
|
||||||
return g_acl_dvpp_resize_config->acldvppDestroyResizeConfig(resizeConfig);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError acldvppCreateChannel(acldvppChannelDesc *channelDesc) {
|
|
||||||
return g_acl_dvpp_channel_desc->acldvppCreateChannel(channelDesc);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError acldvppDestroyChannel(acldvppChannelDesc *channelDesc) {
|
|
||||||
return g_acl_dvpp_channel_desc->acldvppDestroyChannel(channelDesc);
|
|
||||||
}
|
|
||||||
|
|
||||||
acldvppChannelDesc *acldvppCreateChannelDesc() { return g_acl_dvpp_channel_desc->acldvppCreateChannelDesc(); }
|
|
||||||
|
|
||||||
aclError acldvppDestroyChannelDesc(acldvppChannelDesc *channelDesc) {
|
|
||||||
return g_acl_dvpp_channel_desc->acldvppDestroyChannelDesc(channelDesc);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError acldvppVpcResizeAsync(acldvppChannelDesc *channelDesc, acldvppPicDesc *inputDesc, acldvppPicDesc *outputDesc,
|
|
||||||
acldvppResizeConfig *resizeConfig, aclrtStream stream) {
|
|
||||||
return g_acl_dvpp_process->acldvppVpcResizeAsync(channelDesc, inputDesc, outputDesc, resizeConfig, stream);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError acldvppVpcCropAsync(acldvppChannelDesc *channelDesc, acldvppPicDesc *inputDesc, acldvppPicDesc *outputDesc,
|
|
||||||
acldvppRoiConfig *cropArea, aclrtStream stream) {
|
|
||||||
return g_acl_dvpp_process->acldvppVpcCropAsync(channelDesc, inputDesc, outputDesc, cropArea, stream);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError acldvppVpcCropAndPasteAsync(acldvppChannelDesc *channelDesc, acldvppPicDesc *inputDesc,
|
|
||||||
acldvppPicDesc *outputDesc, acldvppRoiConfig *cropArea,
|
|
||||||
acldvppRoiConfig *pasteArea, aclrtStream stream) {
|
|
||||||
return g_acl_dvpp_process->acldvppVpcCropAndPasteAsync(channelDesc, inputDesc, outputDesc, cropArea, pasteArea,
|
|
||||||
stream);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError acldvppVpcBatchCropAsync(acldvppChannelDesc *channelDesc, acldvppBatchPicDesc *srcBatchDesc, uint32_t *roiNums,
|
|
||||||
uint32_t size, acldvppBatchPicDesc *dstBatchDesc, acldvppRoiConfig *cropAreas[],
|
|
||||||
aclrtStream stream) {
|
|
||||||
return g_acl_dvpp_process->acldvppVpcBatchCropAsync(channelDesc, srcBatchDesc, roiNums, size, dstBatchDesc, cropAreas,
|
|
||||||
stream);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError acldvppJpegDecodeAsync(acldvppChannelDesc *channelDesc, const void *data, uint32_t size,
|
|
||||||
acldvppPicDesc *outputDesc, aclrtStream stream) {
|
|
||||||
return g_acl_dvpp_process->acldvppJpegDecodeAsync(channelDesc, data, size, outputDesc, stream);
|
|
||||||
}
|
|
||||||
|
|
||||||
// jpeg lib
|
|
||||||
void jpeg_CreateDecompress(j_decompress_ptr cinfo, int version, size_t structsize) {
|
|
||||||
g_acl_jpeg_lib->jpeg_CreateDecompress(cinfo, version, structsize);
|
|
||||||
}
|
|
||||||
|
|
||||||
void jpeg_mem_src(j_decompress_ptr cinfo, const unsigned char *inbuffer, unsigned long insize) {
|
|
||||||
g_acl_jpeg_lib->jpeg_mem_src(cinfo, inbuffer, insize);
|
|
||||||
}
|
|
||||||
|
|
||||||
int jpeg_read_header(j_decompress_ptr cinfo, boolean require_image) {
|
|
||||||
return g_acl_jpeg_lib->jpeg_read_header(cinfo, require_image);
|
|
||||||
}
|
|
||||||
|
|
||||||
void jpeg_destroy_decompress(j_decompress_ptr cinfo) { g_acl_jpeg_lib->jpeg_destroy_decompress(cinfo); }
|
|
||||||
|
|
||||||
struct jpeg_error_mgr *jpeg_std_error(struct jpeg_error_mgr *err) {
|
|
||||||
return err;
|
|
||||||
}
|
|
|
@ -1,857 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_ACL_STUB_H
|
|
||||||
#define MINDSPORE_ACL_STUB_H
|
|
||||||
|
|
||||||
#include "acl/acl_base.h"
|
|
||||||
#include "acl/acl.h"
|
|
||||||
#include "acl/acl_mdl.h"
|
|
||||||
#include "acl/acl_rt.h"
|
|
||||||
#include "acl/ops/acl_dvpp.h"
|
|
||||||
#include <algorithm>
|
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
|
||||||
#include <unordered_map>
|
|
||||||
#include <map>
|
|
||||||
#include <functional>
|
|
||||||
#include <cstring>
|
|
||||||
#include "jpeglib.h"
|
|
||||||
|
|
||||||
struct aclDataBuffer {
|
|
||||||
void *data = nullptr;
|
|
||||||
size_t size = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct aclmdlDataset {
|
|
||||||
std::vector<aclDataBuffer *> data_buffers;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct aclTensorDesc {};
|
|
||||||
|
|
||||||
struct AclTensorDesc {
|
|
||||||
std::vector<int64_t> dims;
|
|
||||||
aclDataType data_type = ACL_DT_UNDEFINED;
|
|
||||||
size_t size = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct aclmdlDesc {
|
|
||||||
std::vector<AclTensorDesc> inputs;
|
|
||||||
std::vector<AclTensorDesc> outputs;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct acldvppPicDesc {
|
|
||||||
uint32_t size = 0;
|
|
||||||
acldvppPixelFormat format = PIXEL_FORMAT_YUV_400;
|
|
||||||
uint32_t width = 0;
|
|
||||||
uint32_t height = 0;
|
|
||||||
void *dataDev = nullptr;
|
|
||||||
uint32_t widthStride = 0;
|
|
||||||
uint32_t heightStride = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct acldvppRoiConfig {
|
|
||||||
uint32_t left = 0;
|
|
||||||
uint32_t right = 0;
|
|
||||||
uint32_t top = 0;
|
|
||||||
uint32_t bottom = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct acldvppResizeConfig {
|
|
||||||
uint32_t id;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct acldvppChannelDesc {
|
|
||||||
bool channel_valid_flag = false;
|
|
||||||
};
|
|
||||||
|
|
||||||
class AclModel;
|
|
||||||
extern AclModel *g_acl_model;
|
|
||||||
|
|
||||||
template <class Type>
|
|
||||||
aclError AclItemOnDestroy(
|
|
||||||
std::vector<Type> &live, std::vector<Type> &destroy, const Type *destroy_item,
|
|
||||||
std::function<void(Type &list_item)> func_release = [](Type &list_item) {}) {
|
|
||||||
for (auto it = live.begin(); it != live.end(); it++) {
|
|
||||||
if (&(*it) == destroy_item) {
|
|
||||||
func_release(*it);
|
|
||||||
destroy.push_back(*it);
|
|
||||||
live.erase(it);
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class PtType, typename std::enable_if<std::is_pointer<PtType>::value, int>::type = 0>
|
|
||||||
class ResourceBase {
|
|
||||||
public:
|
|
||||||
using Type = typename std::remove_pointer<PtType>::type;
|
|
||||||
ResourceBase() = default;
|
|
||||||
virtual ~ResourceBase() { Clear(); }
|
|
||||||
void Clear() {
|
|
||||||
for (auto item : resource_live_) {
|
|
||||||
delete item;
|
|
||||||
}
|
|
||||||
resource_live_.clear();
|
|
||||||
resource_destroy_.clear();
|
|
||||||
}
|
|
||||||
template <class... Args>
|
|
||||||
Type *OnCreate(Args &&... args) {
|
|
||||||
auto item = new Type(std::forward<Args>(args)...);
|
|
||||||
resource_live_.push_back(item);
|
|
||||||
return item;
|
|
||||||
}
|
|
||||||
aclError OnDestroy(
|
|
||||||
const Type *item, std::function<void(Type &list_item)> func_release = [](Type &list_item) {}) {
|
|
||||||
auto it = std::find(resource_live_.begin(), resource_live_.end(), item);
|
|
||||||
if (it == resource_live_.end()) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
func_release(**it); // Type&
|
|
||||||
resource_destroy_.push_back(*it); // Type*
|
|
||||||
resource_live_.erase(it);
|
|
||||||
delete item;
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
size_t LiveSize() const { return resource_live_.size(); }
|
|
||||||
bool Check() const { return resource_live_.empty(); }
|
|
||||||
std::vector<Type *> resource_live_;
|
|
||||||
std::vector<Type *> resource_destroy_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class AclDataBuffer {
|
|
||||||
public:
|
|
||||||
AclDataBuffer() {}
|
|
||||||
virtual ~AclDataBuffer() { Clear(); }
|
|
||||||
virtual void Clear() { data_buffer_.Clear(); }
|
|
||||||
bool Check() { return data_buffer_.Check(); }
|
|
||||||
|
|
||||||
virtual aclDataBuffer *aclCreateDataBuffer(void *data, size_t size) {
|
|
||||||
aclDataBuffer data_buffer;
|
|
||||||
data_buffer.data = data;
|
|
||||||
data_buffer.size = size;
|
|
||||||
return data_buffer_.OnCreate(data_buffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual aclError aclDestroyDataBuffer(const aclDataBuffer *dataBuffer) { return data_buffer_.OnDestroy(dataBuffer); }
|
|
||||||
|
|
||||||
virtual void *aclGetDataBufferAddr(const aclDataBuffer *dataBuffer) {
|
|
||||||
if (dataBuffer == nullptr) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
return dataBuffer->data;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual uint32_t aclGetDataBufferSize(const aclDataBuffer *dataBuffer) {
|
|
||||||
if (dataBuffer == nullptr) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
return dataBuffer->size;
|
|
||||||
}
|
|
||||||
ResourceBase<aclDataBuffer *> data_buffer_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class AclDataSet {
|
|
||||||
public:
|
|
||||||
AclDataSet() {}
|
|
||||||
virtual ~AclDataSet() { Clear(); }
|
|
||||||
virtual void Clear() { dataset_.Clear(); }
|
|
||||||
bool Check() { return dataset_.Check(); }
|
|
||||||
|
|
||||||
public:
|
|
||||||
virtual aclmdlDataset *aclmdlCreateDataset() { return dataset_.OnCreate(); }
|
|
||||||
virtual aclError aclmdlDestroyDataset(const aclmdlDataset *dataSet) { return dataset_.OnDestroy(dataSet); }
|
|
||||||
virtual aclError aclmdlAddDatasetBuffer(aclmdlDataset *dataSet, aclDataBuffer *dataBuffer) {
|
|
||||||
if (dataSet == nullptr) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
dataSet->data_buffers.push_back(dataBuffer);
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
virtual size_t aclmdlGetDatasetNumBuffers(const aclmdlDataset *dataSet) {
|
|
||||||
if (dataSet == nullptr) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
return dataSet->data_buffers.size();
|
|
||||||
}
|
|
||||||
virtual aclDataBuffer *aclmdlGetDatasetBuffer(const aclmdlDataset *dataSet, size_t index) {
|
|
||||||
if (dataSet == nullptr || index >= dataSet->data_buffers.size()) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
return dataSet->data_buffers[index];
|
|
||||||
}
|
|
||||||
ResourceBase<aclmdlDataset *> dataset_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class AclEnv {
|
|
||||||
public:
|
|
||||||
virtual aclError aclInit(const char *configPath) {
|
|
||||||
is_init = true;
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
virtual aclError aclFinalize() {
|
|
||||||
is_init = false;
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
bool Check() { return is_init == false; }
|
|
||||||
bool is_init = false;
|
|
||||||
};
|
|
||||||
|
|
||||||
class AclModel {
|
|
||||||
public:
|
|
||||||
bool Check() { return model_live_.empty(); }
|
|
||||||
virtual aclError aclmdlLoadFromFile(const char *modelPath, uint32_t *modelId) {
|
|
||||||
model_live_.push_back(cur_max_model_id_);
|
|
||||||
*modelId = cur_max_model_id_;
|
|
||||||
cur_max_model_id_++;
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual aclError aclmdlLoadFromMem(const void *model, size_t modelSize, uint32_t *modelId) {
|
|
||||||
return aclmdlLoadFromFile("fake_path", modelId);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual aclError aclmdlLoadFromFileWithMem(const char *modelPath, uint32_t *modelId, void *workPtr, size_t workSize,
|
|
||||||
void *weightPtr, size_t weightSize) {
|
|
||||||
return aclmdlLoadFromFile(modelPath, modelId);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual aclError aclmdlLoadFromMemWithMem(const void *model, size_t modelSize, uint32_t *modelId, void *workPtr,
|
|
||||||
size_t workSize, void *weightPtr, size_t weightSize) {
|
|
||||||
return aclmdlLoadFromMem(model, modelSize, modelId);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual aclError aclmdlExecute(uint32_t modelId, const aclmdlDataset *input, aclmdlDataset *output) {
|
|
||||||
if (std::find(model_live_.begin(), model_live_.end(), modelId) == model_live_.end()) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
if (input == nullptr || output == nullptr) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
// auto& model_desc = model_live_[modelId];
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual aclError aclmdlExecuteAsync(uint32_t modelId, const aclmdlDataset *input, aclmdlDataset *output,
|
|
||||||
aclrtStream stream) {
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual aclError aclmdlUnload(uint32_t modelId) {
|
|
||||||
auto it = std::find(model_live_.begin(), model_live_.end(), modelId);
|
|
||||||
if (it == model_live_.end()) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
model_live_.erase(it);
|
|
||||||
model_destroy_.push_back(modelId);
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
uint32_t cur_max_model_id_ = 0;
|
|
||||||
std::vector<uint32_t> model_live_;
|
|
||||||
std::vector<uint32_t> model_destroy_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class AclModelDesc {
|
|
||||||
public:
|
|
||||||
AclModelDesc() {}
|
|
||||||
virtual ~AclModelDesc() { Clear(); }
|
|
||||||
virtual void Clear() { model_desc_.Clear(); }
|
|
||||||
bool Check() { return model_desc_.Check(); }
|
|
||||||
|
|
||||||
public:
|
|
||||||
virtual aclmdlDesc *aclmdlCreateDesc() { return model_desc_.OnCreate(); }
|
|
||||||
aclError aclmdlDestroyDesc(aclmdlDesc *modelDesc) { return model_desc_.OnDestroy(modelDesc); }
|
|
||||||
|
|
||||||
aclError aclmdlGetDesc(aclmdlDesc *modelDesc, uint32_t modelId) {
|
|
||||||
auto &model_live = g_acl_model->model_live_;
|
|
||||||
auto it = std::find(model_live.begin(), model_live.end(), modelId);
|
|
||||||
if (it == model_live.end()) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t aclmdlGetNumInputs(aclmdlDesc *modelDesc) { return modelDesc->inputs.size(); }
|
|
||||||
|
|
||||||
size_t aclmdlGetNumOutputs(aclmdlDesc *modelDesc) { return modelDesc->outputs.size(); }
|
|
||||||
|
|
||||||
size_t aclmdlGetInputSizeByIndex(aclmdlDesc *modelDesc, size_t index) { return modelDesc->inputs[index].size; }
|
|
||||||
|
|
||||||
size_t aclmdlGetOutputSizeByIndex(aclmdlDesc *modelDesc, size_t index) { return modelDesc->outputs[index].size; }
|
|
||||||
|
|
||||||
aclError aclmdlGetInputDims(const aclmdlDesc *modelDesc, size_t index, aclmdlIODims *dims) {
|
|
||||||
auto &input = modelDesc->inputs[index];
|
|
||||||
dims->dimCount = input.dims.size();
|
|
||||||
for (size_t i = 0; i < dims->dimCount; i++) {
|
|
||||||
dims->dims[i] = input.dims[i];
|
|
||||||
}
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError aclmdlGetOutputDims(const aclmdlDesc *modelDesc, size_t index, aclmdlIODims *dims) {
|
|
||||||
auto &input = modelDesc->outputs[index];
|
|
||||||
dims->dimCount = input.dims.size();
|
|
||||||
for (size_t i = 0; i < dims->dimCount; i++) {
|
|
||||||
dims->dims[i] = input.dims[i];
|
|
||||||
}
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError aclmdlGetCurOutputDims(const aclmdlDesc *modelDesc, size_t index, aclmdlIODims *dims) {
|
|
||||||
return aclmdlGetOutputDims(modelDesc, index, dims);
|
|
||||||
}
|
|
||||||
|
|
||||||
aclFormat aclmdlGetInputFormat(const aclmdlDesc *modelDesc, size_t index) { return ACL_FORMAT_NCHW; }
|
|
||||||
aclFormat aclmdlGetOutputFormat(const aclmdlDesc *modelDesc, size_t index) { return ACL_FORMAT_NCHW; }
|
|
||||||
|
|
||||||
aclDataType aclmdlGetInputDataType(const aclmdlDesc *modelDesc, size_t index) {
|
|
||||||
return modelDesc->inputs[index].data_type;
|
|
||||||
}
|
|
||||||
|
|
||||||
aclDataType aclmdlGetOutputDataType(const aclmdlDesc *modelDesc, size_t index) {
|
|
||||||
return modelDesc->outputs[index].data_type;
|
|
||||||
}
|
|
||||||
|
|
||||||
ResourceBase<aclmdlDesc *> model_desc_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class AclRunMode {
|
|
||||||
public:
|
|
||||||
virtual aclError aclrtGetRunMode(aclrtRunMode *runMode) {
|
|
||||||
*runMode = aclrtRunMode::ACL_HOST;
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
class AclDeviceContextStream {
|
|
||||||
public:
|
|
||||||
AclDeviceContextStream() {}
|
|
||||||
~AclDeviceContextStream() { Clear(); }
|
|
||||||
virtual void Clear() {
|
|
||||||
for (auto context : context_live_) {
|
|
||||||
delete (int *)context;
|
|
||||||
}
|
|
||||||
context_live_.clear();
|
|
||||||
context_destroy_.clear();
|
|
||||||
device_id_live_.clear();
|
|
||||||
device_id_destroy_.clear();
|
|
||||||
for (auto item : stream_live_) {
|
|
||||||
delete (int *)item;
|
|
||||||
}
|
|
||||||
stream_live_.clear();
|
|
||||||
stream_destroy_.clear();
|
|
||||||
}
|
|
||||||
bool Check() { return context_live_.empty() && device_id_live_.empty() && stream_live_.empty(); }
|
|
||||||
virtual aclError aclrtCreateContext(aclrtContext *context, int32_t deviceId) {
|
|
||||||
context_live_.push_back(new int());
|
|
||||||
*context = context_live_.back();
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
virtual aclError aclrtDestroyContext(aclrtContext context) {
|
|
||||||
for (auto it = context_live_.begin(); it != context_live_.end(); ++it) {
|
|
||||||
if (*it == context) {
|
|
||||||
context_live_.erase(it);
|
|
||||||
context_destroy_.push_back(context);
|
|
||||||
delete (int *)context;
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
aclError aclrtSetCurrentContext(aclrtContext context) { return ACL_ERROR_NONE; }
|
|
||||||
aclError aclrtGetCurrentContext(aclrtContext *context) { return ACL_ERROR_NONE; }
|
|
||||||
virtual aclError aclrtSetDevice(int32_t deviceId) {
|
|
||||||
device_id_live_.push_back(deviceId);
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
virtual aclError aclrtResetDevice(int32_t deviceId) {
|
|
||||||
for (auto it = device_id_live_.begin(); it != device_id_live_.end(); ++it) {
|
|
||||||
if (*it == deviceId) {
|
|
||||||
device_id_live_.erase(it);
|
|
||||||
device_id_destroy_.push_back(deviceId);
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
aclError aclrtGetDevice(int32_t *deviceId) {
|
|
||||||
*deviceId = 0;
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
aclError aclrtSynchronizeDevice(void) { return ACL_ERROR_NONE; }
|
|
||||||
aclError aclrtSetTsDevice(aclrtTsId tsId) { return ACL_ERROR_NONE; }
|
|
||||||
aclError aclrtGetDeviceCount(uint32_t *count) {
|
|
||||||
*count = 1;
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
virtual aclError aclrtCreateStream(aclrtStream *stream) {
|
|
||||||
stream_live_.push_back(new int());
|
|
||||||
*stream = stream_live_.back();
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
virtual aclError aclrtDestroyStream(aclrtStream stream) {
|
|
||||||
for (auto it = stream_live_.begin(); it != context_live_.end(); ++it) {
|
|
||||||
if (*it == stream) {
|
|
||||||
stream_live_.erase(it);
|
|
||||||
stream_destroy_.push_back(stream);
|
|
||||||
delete (int *)stream;
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
aclError aclrtSynchronizeStream(aclrtStream stream) {
|
|
||||||
for (auto it = stream_live_.begin(); it != context_live_.end(); ++it) {
|
|
||||||
if (*it == stream) {
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
std::vector<int32_t> device_id_live_;
|
|
||||||
std::vector<int32_t> device_id_destroy_;
|
|
||||||
std::vector<aclrtContext> context_live_;
|
|
||||||
std::vector<aclrtContext> context_destroy_;
|
|
||||||
std::vector<aclrtStream> stream_live_;
|
|
||||||
std::vector<aclrtStream> stream_destroy_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class AclMemory {
|
|
||||||
public:
|
|
||||||
AclMemory() {}
|
|
||||||
~AclMemory() { Clear(); }
|
|
||||||
void Clear() {
|
|
||||||
for (auto item : device_buffer_live_) {
|
|
||||||
delete[] item;
|
|
||||||
}
|
|
||||||
for (auto item : host_buffer_live_) {
|
|
||||||
delete[] item;
|
|
||||||
}
|
|
||||||
for (auto item : dvpp_buffer_live_) {
|
|
||||||
delete[] item;
|
|
||||||
}
|
|
||||||
device_buffer_live_.clear();
|
|
||||||
device_buffer_destroy_.clear();
|
|
||||||
host_buffer_live_.clear();
|
|
||||||
host_buffer_destroy_.clear();
|
|
||||||
dvpp_buffer_live_.clear();
|
|
||||||
dvpp_buffer_destroy_.clear();
|
|
||||||
}
|
|
||||||
bool Check() { return device_buffer_live_.empty() && host_buffer_live_.empty() && dvpp_buffer_live_.empty(); }
|
|
||||||
virtual aclError aclrtMalloc(void **devPtr, size_t size, aclrtMemMallocPolicy policy) {
|
|
||||||
auto buffer = new uint8_t[size];
|
|
||||||
*devPtr = buffer;
|
|
||||||
device_buffer_live_.push_back(buffer);
|
|
||||||
memory_len_[buffer] = size;
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
aclError aclrtFree(void *devPtr) {
|
|
||||||
auto it = std::find(device_buffer_live_.begin(), device_buffer_live_.end(), devPtr);
|
|
||||||
if (it != device_buffer_live_.end()) {
|
|
||||||
delete[](*it);
|
|
||||||
device_buffer_live_.erase(it);
|
|
||||||
device_buffer_destroy_.push_back(*it);
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual aclError aclrtMallocHost(void **hostPtr, size_t size) {
|
|
||||||
auto buffer = new uint8_t[size];
|
|
||||||
*hostPtr = buffer;
|
|
||||||
host_buffer_live_.push_back(buffer);
|
|
||||||
memory_len_[buffer] = size;
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError aclrtFreeHost(void *hostPtr) {
|
|
||||||
auto it = std::find(host_buffer_live_.begin(), host_buffer_live_.end(), hostPtr);
|
|
||||||
if (it != host_buffer_live_.end()) {
|
|
||||||
delete[](*it);
|
|
||||||
host_buffer_live_.erase(it);
|
|
||||||
host_buffer_destroy_.push_back(*it);
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError aclrtMemcpy(void *dst, size_t destMax, const void *src, size_t count, aclrtMemcpyKind kind) {
|
|
||||||
auto is_device_memory = [this](const void *memory, uint32_t use_size) {
|
|
||||||
for (auto it = device_buffer_live_.begin(); it != device_buffer_live_.end(); it++) {
|
|
||||||
auto size = memory_len_[*it];
|
|
||||||
if (memory >= *it && static_cast<const uint8_t *>(memory) + use_size <= (*it) + size) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (auto it = dvpp_buffer_live_.begin(); it != dvpp_buffer_live_.end(); it++) {
|
|
||||||
auto size = memory_len_[*it];
|
|
||||||
if (memory >= *it && static_cast<const uint8_t *>(memory) + use_size <= (*it) + size) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
};
|
|
||||||
if (kind == ACL_MEMCPY_HOST_TO_HOST) {
|
|
||||||
if (is_device_memory(dst, destMax) || is_device_memory(src, count)) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
} else if (kind == ACL_MEMCPY_HOST_TO_DEVICE) {
|
|
||||||
if (!is_device_memory(dst, destMax) || is_device_memory(src, count)) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
} else if (kind == ACL_MEMCPY_DEVICE_TO_HOST) {
|
|
||||||
if (is_device_memory(dst, destMax) || !is_device_memory(src, count)) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
} else if (kind == ACL_MEMCPY_DEVICE_TO_DEVICE) {
|
|
||||||
if (!is_device_memory(dst, destMax) || !is_device_memory(src, count)) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
memcpy(dst, src, count);
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual aclError acldvppMalloc(void **devPtr, size_t size) {
|
|
||||||
auto buffer = new uint8_t[size];
|
|
||||||
*devPtr = buffer;
|
|
||||||
dvpp_buffer_live_.push_back(buffer);
|
|
||||||
memory_len_[buffer] = size;
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
aclError acldvppFree(void *devPtr) {
|
|
||||||
auto it = std::find(dvpp_buffer_live_.begin(), dvpp_buffer_live_.end(), devPtr);
|
|
||||||
if (it != dvpp_buffer_live_.end()) {
|
|
||||||
delete[](*it);
|
|
||||||
dvpp_buffer_live_.erase(it);
|
|
||||||
dvpp_buffer_destroy_.push_back(*it);
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<uint8_t *> device_buffer_live_;
|
|
||||||
std::vector<uint8_t *> device_buffer_destroy_;
|
|
||||||
std::vector<uint8_t *> host_buffer_live_;
|
|
||||||
std::vector<uint8_t *> host_buffer_destroy_;
|
|
||||||
std::vector<uint8_t *> dvpp_buffer_live_;
|
|
||||||
std::vector<uint8_t *> dvpp_buffer_destroy_;
|
|
||||||
std::map<uint8_t *, uint32_t> memory_len_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class AclDvppPicDesc {
|
|
||||||
public:
|
|
||||||
bool Check() { return pic_desc_.Check(); }
|
|
||||||
acldvppPicDesc *acldvppCreatePicDesc() { return pic_desc_.OnCreate(); }
|
|
||||||
|
|
||||||
aclError acldvppDestroyPicDesc(acldvppPicDesc *picDesc) { return pic_desc_.OnDestroy(picDesc); }
|
|
||||||
|
|
||||||
aclError acldvppSetPicDescSize(acldvppPicDesc *picDesc, uint32_t size) {
|
|
||||||
picDesc->size = size;
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError acldvppSetPicDescFormat(acldvppPicDesc *picDesc, acldvppPixelFormat format) {
|
|
||||||
picDesc->format = format;
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError acldvppSetPicDescWidth(acldvppPicDesc *picDesc, uint32_t width) {
|
|
||||||
picDesc->width = width;
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError acldvppSetPicDescHeight(acldvppPicDesc *picDesc, uint32_t height) {
|
|
||||||
picDesc->height = height;
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError acldvppSetPicDescData(acldvppPicDesc *picDesc, void *dataDev) {
|
|
||||||
picDesc->dataDev = dataDev;
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError acldvppSetPicDescWidthStride(acldvppPicDesc *picDesc, uint32_t widthStride) {
|
|
||||||
picDesc->widthStride = widthStride;
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError acldvppSetPicDescHeightStride(acldvppPicDesc *picDesc, uint32_t heightStride) {
|
|
||||||
picDesc->heightStride = heightStride;
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
ResourceBase<acldvppPicDesc *> pic_desc_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class AclDvppRoiConfig {
|
|
||||||
public:
|
|
||||||
bool Check() { return roi_config_.Check(); }
|
|
||||||
acldvppRoiConfig *acldvppCreateRoiConfig(uint32_t left, uint32_t right, uint32_t top, uint32_t bottom) {
|
|
||||||
return roi_config_.OnCreate(acldvppRoiConfig{.left = left, .right = right, .top = top, .bottom = bottom});
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError acldvppDestroyRoiConfig(acldvppRoiConfig *roiConfig) { return roi_config_.OnDestroy(roiConfig); }
|
|
||||||
|
|
||||||
aclError acldvppSetRoiConfig(acldvppRoiConfig *roiConfig, uint32_t left, uint32_t right, uint32_t top,
|
|
||||||
uint32_t bottom) {
|
|
||||||
roiConfig->left = left;
|
|
||||||
roiConfig->right = right;
|
|
||||||
roiConfig->top = top;
|
|
||||||
roiConfig->bottom = bottom;
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
ResourceBase<acldvppRoiConfig *> roi_config_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class AclDvppResizeConfig {
|
|
||||||
public:
|
|
||||||
bool Check() { return resize_config_.Check(); }
|
|
||||||
acldvppResizeConfig *acldvppCreateResizeConfig() { return resize_config_.OnCreate(acldvppResizeConfig{}); }
|
|
||||||
|
|
||||||
aclError acldvppDestroyResizeConfig(acldvppResizeConfig *resizeConfig) {
|
|
||||||
return resize_config_.OnDestroy(resizeConfig);
|
|
||||||
}
|
|
||||||
ResourceBase<acldvppResizeConfig *> resize_config_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class AclDvppChannelDesc {
|
|
||||||
public:
|
|
||||||
bool Check() { return channel_desc_.Check(); }
|
|
||||||
aclError acldvppCreateChannel(acldvppChannelDesc *channelDesc) {
|
|
||||||
channelDesc->channel_valid_flag = true;
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
aclError acldvppDestroyChannel(acldvppChannelDesc *channelDesc) {
|
|
||||||
channelDesc->channel_valid_flag = false;
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
acldvppChannelDesc *acldvppCreateChannelDesc() { return channel_desc_.OnCreate(); }
|
|
||||||
aclError acldvppDestroyChannelDesc(acldvppChannelDesc *channelDesc) {
|
|
||||||
if (channelDesc->channel_valid_flag) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
return channel_desc_.OnDestroy(channelDesc);
|
|
||||||
}
|
|
||||||
ResourceBase<acldvppChannelDesc *> channel_desc_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class AclDvppProcess {
|
|
||||||
public:
|
|
||||||
bool Check() { return true; }
|
|
||||||
virtual aclError acldvppVpcResizeAsync(acldvppChannelDesc *channelDesc, acldvppPicDesc *inputDesc,
|
|
||||||
acldvppPicDesc *outputDesc, acldvppResizeConfig *resizeConfig,
|
|
||||||
aclrtStream stream) {
|
|
||||||
resize_call_times_++;
|
|
||||||
if (channelDesc == nullptr || inputDesc == nullptr || outputDesc == nullptr || resizeConfig == nullptr ||
|
|
||||||
stream == nullptr) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
if (CheckPicDesc(inputDesc) != ACL_ERROR_NONE) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
if (CheckPicDesc(outputDesc) != ACL_ERROR_NONE) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual aclError acldvppVpcCropAsync(acldvppChannelDesc *channelDesc, acldvppPicDesc *inputDesc,
|
|
||||||
acldvppPicDesc *outputDesc, acldvppRoiConfig *cropArea, aclrtStream stream) {
|
|
||||||
crop_call_times_++;
|
|
||||||
if (channelDesc == nullptr || inputDesc == nullptr || outputDesc == nullptr || cropArea == nullptr ||
|
|
||||||
stream == nullptr) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
if (CheckPicDesc(inputDesc) != ACL_ERROR_NONE) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
if (CheckPicDesc(outputDesc) != ACL_ERROR_NONE) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
if (CheckCropArea(cropArea) != ACL_ERROR_NONE) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual aclError acldvppVpcCropAndPasteAsync(acldvppChannelDesc *channelDesc, acldvppPicDesc *inputDesc,
|
|
||||||
acldvppPicDesc *outputDesc, acldvppRoiConfig *cropArea,
|
|
||||||
acldvppRoiConfig *pasteArea, aclrtStream stream) {
|
|
||||||
crop_paste_call_times_++;
|
|
||||||
if (channelDesc == nullptr || inputDesc == nullptr || outputDesc == nullptr || cropArea == nullptr ||
|
|
||||||
pasteArea == nullptr || stream == nullptr) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
if (CheckPicDesc(inputDesc) != ACL_ERROR_NONE) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
if (CheckPicDesc(outputDesc) != ACL_ERROR_NONE) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
if (CheckCropArea(cropArea) != ACL_ERROR_NONE) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
if (CheckCropArea(pasteArea) != ACL_ERROR_NONE) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
|
|
||||||
aclError acldvppVpcBatchCropAsync(acldvppChannelDesc *channelDesc, acldvppBatchPicDesc *srcBatchDesc,
|
|
||||||
uint32_t *roiNums, uint32_t size, acldvppBatchPicDesc *dstBatchDesc,
|
|
||||||
acldvppRoiConfig *cropAreas[], aclrtStream stream) {
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual aclError acldvppJpegDecodeAsync(acldvppChannelDesc *channelDesc, const void *data, uint32_t size,
|
|
||||||
acldvppPicDesc *outputDesc, aclrtStream stream) {
|
|
||||||
decode_call_times_++;
|
|
||||||
if (channelDesc == nullptr || data == nullptr || size == 0 || outputDesc == nullptr || stream == nullptr) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
if (outputDesc->widthStride % 128 != 0) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
if (outputDesc->heightStride % 16 != 0) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
if (outputDesc->widthStride < 32 || outputDesc->widthStride > 8192) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
if (outputDesc->heightStride < 32 || outputDesc->heightStride > 8192) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
if (CheckPicDesc(outputDesc) != ACL_ERROR_NONE) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
aclError CheckCropArea(acldvppRoiConfig *crop_area) {
|
|
||||||
if (crop_area->left % 2 != 0 || crop_area->top % 2 != 0) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
if (crop_area->right % 2 != 1 || crop_area->bottom % 2 != 1) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
auto crop_width = crop_area->right - crop_area->left + 1;
|
|
||||||
if (crop_width < 10 || crop_width > 4096) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
auto crop_heigth = crop_area->bottom - crop_area->top + 1;
|
|
||||||
if (crop_heigth < 6 || crop_heigth > 4096) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
aclError CheckPicDesc(acldvppPicDesc *pic_desc) {
|
|
||||||
if (pic_desc->width == 0 || pic_desc->height == 0) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
if (pic_desc->widthStride % 16 != 0 || pic_desc->widthStride < pic_desc->width) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
if (pic_desc->heightStride % 2 != 0 || pic_desc->heightStride < pic_desc->height) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
if (pic_desc->widthStride < 32 || pic_desc->widthStride > 4096) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
if (pic_desc->heightStride < 6 || pic_desc->heightStride > 4096) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
if (pic_desc->dataDev == nullptr) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
auto size = pic_desc->size;
|
|
||||||
auto ele_cnt = pic_desc->widthStride * pic_desc->heightStride;
|
|
||||||
switch (pic_desc->format) {
|
|
||||||
case PIXEL_FORMAT_YUV_SEMIPLANAR_420:
|
|
||||||
case PIXEL_FORMAT_YVU_SEMIPLANAR_420:
|
|
||||||
if (ele_cnt * 3 / 2 != size) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case PIXEL_FORMAT_YUV_SEMIPLANAR_422:
|
|
||||||
case PIXEL_FORMAT_YVU_SEMIPLANAR_422:
|
|
||||||
if (ele_cnt * 2 != size) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case PIXEL_FORMAT_YUV_SEMIPLANAR_444:
|
|
||||||
case PIXEL_FORMAT_YVU_SEMIPLANAR_444:
|
|
||||||
if (ele_cnt * 3 != size) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
return ACL_ERROR_NONE;
|
|
||||||
}
|
|
||||||
uint32_t decode_call_times_ = 0;
|
|
||||||
uint32_t resize_call_times_ = 0;
|
|
||||||
uint32_t crop_call_times_ = 0;
|
|
||||||
uint32_t crop_paste_call_times_ = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
class AclJpegLib {
|
|
||||||
public:
|
|
||||||
bool Check() { return jpeg_live_.empty(); }
|
|
||||||
AclJpegLib(uint32_t width, uint32_t height) : image_width_(width), image_height_(height) {}
|
|
||||||
|
|
||||||
void jpeg_CreateDecompress(j_decompress_ptr cinfo, int version, size_t structsize) { jpeg_live_.push_back(cinfo); }
|
|
||||||
void jpeg_mem_src(j_decompress_ptr cinfo, const unsigned char *inbuffer, unsigned long insize) {}
|
|
||||||
int jpeg_read_header(j_decompress_ptr cinfo, boolean require_image) {
|
|
||||||
static JHUFF_TBL tal;
|
|
||||||
cinfo->image_width = image_width_;
|
|
||||||
cinfo->image_height = image_height_;
|
|
||||||
cinfo->jpeg_color_space = color_space_;
|
|
||||||
for (int i = 0; i < NUM_HUFF_TBLS; i++) {
|
|
||||||
cinfo->ac_huff_tbl_ptrs[i] = &tal;
|
|
||||||
cinfo->dc_huff_tbl_ptrs[i] = &tal;
|
|
||||||
}
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
void jpeg_destroy_decompress(j_decompress_ptr cinfo) {
|
|
||||||
auto it = std::find(jpeg_live_.begin(), jpeg_live_.end(), cinfo);
|
|
||||||
if (it != jpeg_live_.end()) {
|
|
||||||
jpeg_live_.erase(it);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
uint32_t image_width_;
|
|
||||||
uint32_t image_height_;
|
|
||||||
J_COLOR_SPACE color_space_ = JCS_YCbCr;
|
|
||||||
std::vector<j_decompress_ptr> jpeg_live_;
|
|
||||||
};
|
|
||||||
|
|
||||||
extern AclDataBuffer *g_acl_data_buffer;
|
|
||||||
extern AclEnv *g_acl_env;
|
|
||||||
extern AclDataSet *g_acl_dataset;
|
|
||||||
extern AclModelDesc *g_acl_model_desc;
|
|
||||||
extern AclDeviceContextStream *g_acl_device_context_stream;
|
|
||||||
extern AclMemory *g_acl_memory;
|
|
||||||
extern AclDvppPicDesc *g_acl_dvpp_pic_desc;
|
|
||||||
extern AclDvppRoiConfig *g_acl_dvpp_roi_config;
|
|
||||||
extern AclDvppResizeConfig *g_acl_dvpp_resize_config;
|
|
||||||
extern AclDvppChannelDesc *g_acl_dvpp_channel_desc;
|
|
||||||
extern AclDvppProcess *g_acl_dvpp_process;
|
|
||||||
extern AclRunMode *g_acl_run_mode;
|
|
||||||
extern AclJpegLib *g_acl_jpeg_lib;
|
|
||||||
|
|
||||||
#endif // MINDSPORE_ACL_STUB_H
|
|
|
@ -1 +0,0 @@
|
||||||
../../../../serving/ms_service.proto
|
|
Loading…
Reference in New Issue