serving add acl support, extract common inference interface
This commit is contained in:
parent
bfc18704d5
commit
314208633b
|
@ -51,6 +51,8 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR})
|
|||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/third_party/flatbuffers/include)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/third_party/flatbuffers/include/flatbuffers)
|
||||
|
||||
if (NOT ENABLE_ACL)
|
||||
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/dependency_utils.cmake)
|
||||
find_package(Python3 3.7 COMPONENTS Interpreter Development)
|
||||
if(Python3_FOUND)
|
||||
|
@ -100,8 +102,12 @@ if (ENABLE_TESTCASES)
|
|||
add_subdirectory(tests)
|
||||
endif()
|
||||
|
||||
endif() # NOT ENABLE_ACL
|
||||
|
||||
if (ENABLE_SERVING)
|
||||
add_subdirectory(serving)
|
||||
endif()
|
||||
|
||||
if (NOT ENABLE_ACL)
|
||||
include(cmake/package.cmake)
|
||||
endif() # NOT ENABLE_ACL
|
||||
|
|
17
build.sh
17
build.sh
|
@ -25,7 +25,7 @@ usage()
|
|||
echo "Usage:"
|
||||
echo "bash build.sh [-d] [-r] [-v] [-c on|off] [-t on|off] [-g on|off] [-h] [-b ge] [-m infer|train] \\"
|
||||
echo " [-a on|off] [-Q on|off] [-p on|off] [-i] [-L] [-R] [-D on|off] [-j[n]] [-e gpu|d|cpu] \\"
|
||||
echo " [-P on|off] [-z [on|off]] [-M on|off] [-V 9.2|10.1] [-I] [-K] [-B on|off] [-E] [-l on|off]"
|
||||
echo " [-P on|off] [-z [on|off]] [-M on|off] [-V 9.2|10.1] [-I] [-K] [-B on|off] [-w on|off] [-E] [-l on|off]"
|
||||
echo ""
|
||||
echo "Options:"
|
||||
echo " -d Debug mode"
|
||||
|
@ -54,6 +54,7 @@ usage()
|
|||
echo " -I Compile predict, default off"
|
||||
echo " -K Compile with AKG, default on"
|
||||
echo " -s Enable serving module, default off"
|
||||
echo " -w Enable acl module, default off"
|
||||
echo " -B Enable debugger, default off"
|
||||
echo " -E Enable IBVERBS for parameter server, default off"
|
||||
echo " -l Compile with python dependency, default on"
|
||||
|
@ -97,12 +98,13 @@ checkopts()
|
|||
PREDICT_PLATFORM=""
|
||||
ENABLE_AKG="on"
|
||||
ENABLE_SERVING="off"
|
||||
ENABLE_ACL="off"
|
||||
ENABLE_DEBUGGER="off"
|
||||
ENABLE_IBVERBS="off"
|
||||
ENABLE_PYTHON="on"
|
||||
|
||||
# Process the options
|
||||
while getopts 'drvj:c:t:hsb:a:g:p:ie:m:l:I:LRP:Q:D:zM:V:K:sB:E' opt
|
||||
while getopts 'drvj:c:t:hsb:a:g:p:ie:m:l:I:LRP:Q:D:zM:V:K:swB:E' opt
|
||||
do
|
||||
OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]')
|
||||
case "${opt}" in
|
||||
|
@ -256,6 +258,10 @@ checkopts()
|
|||
ENABLE_SERVING="on"
|
||||
echo "enable serving"
|
||||
;;
|
||||
w)
|
||||
ENABLE_ACL="on"
|
||||
echo "enable acl"
|
||||
;;
|
||||
B)
|
||||
check_on_off $OPTARG B
|
||||
ENABLE_DEBUGGER="on"
|
||||
|
@ -348,6 +354,9 @@ build_mindspore()
|
|||
if [[ "X$ENABLE_SERVING" = "Xon" ]]; then
|
||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_SERVING=ON"
|
||||
fi
|
||||
if [[ "X$ENABLE_ACL" = "Xon" ]]; then
|
||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_ACL=ON"
|
||||
fi
|
||||
if [[ "X$ENABLE_DEBUGGER" = "Xon" ]]; then
|
||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_DEBUGGER=ON"
|
||||
fi
|
||||
|
@ -362,7 +371,11 @@ build_mindspore()
|
|||
if [[ -n "$VERBOSE" ]]; then
|
||||
CMAKE_VERBOSE="--verbose"
|
||||
fi
|
||||
if [[ "X$ENABLE_ACL" = "Xon" ]]; then
|
||||
cmake --build . ${CMAKE_VERBOSE} -j$THREAD_NUM
|
||||
else
|
||||
cmake --build . --target package ${CMAKE_VERBOSE} -j$THREAD_NUM
|
||||
fi
|
||||
echo "success to build mindspore project!"
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_INFERENCE_LOG_H_
|
||||
#define MINDSPORE_INFERENCE_LOG_H_
|
||||
|
||||
#include <stdarg.h>
|
||||
#include <stdint.h>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <memory>
|
||||
#include <iostream>
|
||||
|
||||
#ifndef ENABLE_ACL
|
||||
#include "mindspore/ccsrc/utils/log_adapter.h"
|
||||
namespace mindspore::inference {
|
||||
#define MSI_LOG(level) MS_LOG(level)
|
||||
|
||||
#define MSI_LOG_DEBUG MSI_LOG(DEBUG)
|
||||
#define MSI_LOG_INFO MSI_LOG(INFO)
|
||||
#define MSI_LOG_WARNING MSI_LOG(WARNING)
|
||||
#define MSI_LOG_ERROR MSI_LOG(ERROR)
|
||||
|
||||
#define MSI_ASSERT(item) MS_ASSERT(item)
|
||||
} // namespace mindspore::inference
|
||||
|
||||
#else // ENABLE_ACL
|
||||
#include "acl/acl.h"
|
||||
namespace mindspore::inference {
|
||||
|
||||
class LogStream {
|
||||
public:
|
||||
LogStream() { sstream_ = std::make_shared<std::stringstream>(); }
|
||||
~LogStream() = default;
|
||||
|
||||
template <typename T>
|
||||
LogStream &operator<<(const T &val) noexcept {
|
||||
(*sstream_) << val;
|
||||
return *this;
|
||||
}
|
||||
|
||||
LogStream &operator<<(std::ostream &func(std::ostream &os)) noexcept {
|
||||
(*sstream_) << func;
|
||||
return *this;
|
||||
}
|
||||
|
||||
friend class LogWriter;
|
||||
|
||||
private:
|
||||
std::shared_ptr<std::stringstream> sstream_;
|
||||
};
|
||||
|
||||
template <class T, typename std::enable_if<std::is_enum<T>::value, int>::type = 0>
|
||||
constexpr std::ostream &operator<<(std::ostream &stream, const T &value) {
|
||||
return stream << static_cast<typename std::underlying_type<T>::type>(value);
|
||||
}
|
||||
|
||||
class LogWriter {
|
||||
public:
|
||||
LogWriter(const char *file, int line, const char *func, aclLogLevel log_level)
|
||||
: file_(file), line_(line), func_(func), log_level_(log_level) {}
|
||||
~LogWriter() = default;
|
||||
|
||||
void operator<(const LogStream &stream) const noexcept __attribute__((visibility("default"))) {
|
||||
std::ostringstream msg;
|
||||
msg << stream.sstream_->rdbuf();
|
||||
OutputLog(msg);
|
||||
}
|
||||
|
||||
private:
|
||||
void OutputLog(const std::ostringstream &msg) const { aclAppLog(log_level_, func_, file_, line_, msg.str().c_str()); }
|
||||
|
||||
const char *file_;
|
||||
int line_;
|
||||
const char *func_;
|
||||
aclLogLevel log_level_;
|
||||
};
|
||||
|
||||
#define MSILOG_IF(level) inference::LogWriter(__FILE__, __LINE__, __FUNCTION__, ACL_##level) < inference::LogStream()
|
||||
|
||||
#define MSI_LOG(level) MSI_LOG_##level
|
||||
|
||||
#define MSI_LOG_DEBUG MSILOG_IF(DEBUG)
|
||||
#define MSI_LOG_INFO MSILOG_IF(INFO)
|
||||
#define MSI_LOG_WARNING MSILOG_IF(WARNING)
|
||||
#define MSI_LOG_ERROR MSILOG_IF(ERROR)
|
||||
|
||||
#define MSI_ASSERT(item)
|
||||
|
||||
} // namespace mindspore::inference
|
||||
|
||||
#endif // ENABLE_ACL
|
||||
|
||||
#endif // MINDSPORE_INFERENCE_LOG_H_
|
|
@ -0,0 +1,191 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_INCLUDE_INFER_TENSOR_H_
|
||||
#define MINDSPORE_INCLUDE_INFER_TENSOR_H_
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <map>
|
||||
#include <functional>
|
||||
|
||||
#include "securec/include/securec.h"
|
||||
#include "include/infer_log.h"
|
||||
|
||||
namespace mindspore {
|
||||
#define MS_API __attribute__((visibility("default")))
|
||||
namespace inference {
|
||||
|
||||
enum DataType {
|
||||
kMSI_Unknown = 0,
|
||||
kMSI_Bool = 1,
|
||||
kMSI_Int8 = 2,
|
||||
kMSI_Int16 = 3,
|
||||
kMSI_Int32 = 4,
|
||||
kMSI_Int64 = 5,
|
||||
kMSI_Uint8 = 6,
|
||||
kMSI_Uint16 = 7,
|
||||
kMSI_Uint32 = 8,
|
||||
kMSI_Uint64 = 9,
|
||||
kMSI_Float16 = 10,
|
||||
kMSI_Float32 = 11,
|
||||
kMSI_Float64 = 12,
|
||||
};
|
||||
|
||||
class InferTensorBase {
|
||||
public:
|
||||
InferTensorBase() = default;
|
||||
virtual ~InferTensorBase() = default;
|
||||
|
||||
virtual DataType data_type() const = 0;
|
||||
virtual void set_data_type(DataType type) = 0;
|
||||
virtual std::vector<int64_t> shape() const = 0;
|
||||
virtual void set_shape(const std::vector<int64_t> &shape) = 0;
|
||||
virtual const void *data() const = 0;
|
||||
virtual size_t data_size() const = 0;
|
||||
virtual bool resize_data(size_t data_len) = 0;
|
||||
virtual void *mutable_data() = 0;
|
||||
|
||||
bool set_data(const void *data, size_t data_len) {
|
||||
resize_data(data_len);
|
||||
if (mutable_data() == nullptr) {
|
||||
MSI_LOG_ERROR << "set data failed, data len " << data_len;
|
||||
return false;
|
||||
}
|
||||
if (data_size() != data_len) {
|
||||
MSI_LOG_ERROR << "set data failed, tensor current data size " << data_size() << " not match data len "
|
||||
<< data_len;
|
||||
return false;
|
||||
}
|
||||
if (data_len == 0) {
|
||||
return true;
|
||||
}
|
||||
memcpy_s(mutable_data(), data_size(), data, data_len);
|
||||
return true;
|
||||
}
|
||||
|
||||
int64_t ElementNum() const {
|
||||
std::vector<int64_t> shapex = shape();
|
||||
return std::accumulate(shapex.begin(), shapex.end(), 1LL, std::multiplies<int64_t>());
|
||||
}
|
||||
|
||||
int GetTypeSize(DataType type) const {
|
||||
const std::map<DataType, size_t> type_size_map{
|
||||
{kMSI_Bool, sizeof(bool)}, {kMSI_Float64, sizeof(double)}, {kMSI_Int8, sizeof(int8_t)},
|
||||
{kMSI_Uint8, sizeof(uint8_t)}, {kMSI_Int16, sizeof(int16_t)}, {kMSI_Uint16, sizeof(uint16_t)},
|
||||
{kMSI_Int32, sizeof(int32_t)}, {kMSI_Uint32, sizeof(uint32_t)}, {kMSI_Int64, sizeof(int64_t)},
|
||||
{kMSI_Uint64, sizeof(uint64_t)}, {kMSI_Float16, sizeof(uint16_t)}, {kMSI_Float32, sizeof(float)},
|
||||
};
|
||||
auto it = type_size_map.find(type);
|
||||
if (it != type_size_map.end()) {
|
||||
return it->second;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
class InferTensor : public InferTensorBase {
|
||||
public:
|
||||
DataType type_;
|
||||
std::vector<int64_t> shape_;
|
||||
std::vector<uint8_t> data_;
|
||||
|
||||
public:
|
||||
InferTensor() = default;
|
||||
InferTensor(DataType type, std::vector<int64_t> shape, const void *data, size_t data_len) {
|
||||
set_data_type(type);
|
||||
set_shape(shape);
|
||||
set_data(data, data_len);
|
||||
}
|
||||
|
||||
void set_data_type(DataType type) override { type_ = type; }
|
||||
DataType data_type() const override { return type_; }
|
||||
|
||||
void set_shape(const std::vector<int64_t> &shape) override { shape_ = shape; }
|
||||
std::vector<int64_t> shape() const override { return shape_; }
|
||||
|
||||
const void *data() const override { return data_.data(); }
|
||||
size_t data_size() const override { return data_.size(); }
|
||||
|
||||
bool resize_data(size_t data_len) override {
|
||||
data_.resize(data_len);
|
||||
return true;
|
||||
}
|
||||
void *mutable_data() override { return data_.data(); }
|
||||
};
|
||||
|
||||
class RequestBase {
|
||||
public:
|
||||
virtual size_t size() const = 0;
|
||||
virtual const InferTensorBase *operator[](size_t index) const = 0;
|
||||
};
|
||||
|
||||
class ReplyBase {
|
||||
public:
|
||||
virtual size_t size() const = 0;
|
||||
virtual InferTensorBase *operator[](size_t index) = 0;
|
||||
virtual const InferTensorBase *operator[](size_t index) const = 0;
|
||||
virtual InferTensorBase *add() = 0;
|
||||
virtual void clear() = 0;
|
||||
};
|
||||
|
||||
class VectorInferTensorWrapReply : public ReplyBase {
|
||||
public:
|
||||
explicit VectorInferTensorWrapReply(std::vector<InferTensor> &tensor_list) : tensor_list_(tensor_list) {}
|
||||
|
||||
size_t size() const { return tensor_list_.size(); }
|
||||
InferTensorBase *operator[](size_t index) {
|
||||
if (index >= tensor_list_.size()) {
|
||||
MSI_LOG_ERROR << "visit invalid index " << index << " total size " << tensor_list_.size();
|
||||
return nullptr;
|
||||
}
|
||||
return &(tensor_list_[index]);
|
||||
}
|
||||
const InferTensorBase *operator[](size_t index) const {
|
||||
if (index >= tensor_list_.size()) {
|
||||
MSI_LOG_ERROR << "visit invalid index " << index << " total size " << tensor_list_.size();
|
||||
return nullptr;
|
||||
}
|
||||
return &(tensor_list_[index]);
|
||||
}
|
||||
InferTensorBase *add() {
|
||||
tensor_list_.push_back(InferTensor());
|
||||
return &(tensor_list_.back());
|
||||
}
|
||||
void clear() { tensor_list_.clear(); }
|
||||
std::vector<InferTensor> &tensor_list_;
|
||||
};
|
||||
|
||||
class VectorInferTensorWrapRequest : public RequestBase {
|
||||
public:
|
||||
explicit VectorInferTensorWrapRequest(const std::vector<InferTensor> &tensor_list) : tensor_list_(tensor_list) {}
|
||||
|
||||
size_t size() const { return tensor_list_.size(); }
|
||||
const InferTensorBase *operator[](size_t index) const {
|
||||
if (index >= tensor_list_.size()) {
|
||||
MSI_LOG_ERROR << "visit invalid index " << index << " total size " << tensor_list_.size();
|
||||
return nullptr;
|
||||
}
|
||||
return &(tensor_list_[index]);
|
||||
}
|
||||
const std::vector<InferTensor> &tensor_list_;
|
||||
};
|
||||
|
||||
} // namespace inference
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_INFER_TENSOR_H_
|
|
@ -20,28 +20,32 @@
|
|||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "include/ms_tensor.h"
|
||||
#include "include/infer_tensor.h"
|
||||
|
||||
namespace mindspore {
|
||||
class FuncGraph;
|
||||
namespace inference {
|
||||
using VectorForMSTensorPtr = std::vector<std::shared_ptr<inference::MSTensor>>;
|
||||
class MS_API MSSession {
|
||||
|
||||
class MS_API InferSession {
|
||||
public:
|
||||
MSSession() = default;
|
||||
InferSession() = default;
|
||||
virtual ~InferSession() = default;
|
||||
virtual bool InitEnv(const std::string &device_type, uint32_t device_id) = 0;
|
||||
virtual bool FinalizeEnv() = 0;
|
||||
virtual bool LoadModelFromFile(const std::string &file_name, uint32_t &model_id) = 0;
|
||||
virtual bool UnloadModel(uint32_t model_id) = 0;
|
||||
// override this method to avoid request/reply data copy
|
||||
virtual bool ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) = 0;
|
||||
|
||||
static std::shared_ptr<MSSession> CreateSession(const std::string &device, uint32_t device_id);
|
||||
virtual bool ExecuteModel(uint32_t model_id, const std::vector<InferTensor> &inputs,
|
||||
std::vector<InferTensor> &outputs) {
|
||||
VectorInferTensorWrapRequest request(inputs);
|
||||
VectorInferTensorWrapReply reply(outputs);
|
||||
return ExecuteModel(model_id, request, reply);
|
||||
}
|
||||
|
||||
virtual uint32_t CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr) = 0;
|
||||
|
||||
virtual MultiTensor RunGraph(uint32_t graph_id, const VectorForMSTensorPtr &inputs) = 0;
|
||||
|
||||
virtual bool CheckModelInputs(uint32_t graph_id, const VectorForMSTensorPtr &inputs) const = 0;
|
||||
static std::shared_ptr<InferSession> CreateSession(const std::string &device, uint32_t device_id);
|
||||
};
|
||||
|
||||
std::shared_ptr<FuncGraph> MS_API LoadModel(const char *model_buf, size_t size, const std::string &device);
|
||||
|
||||
void MS_API ExitInference();
|
||||
} // namespace inference
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_MS_SESSION_H
|
||||
|
|
|
@ -1,69 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_INCLUDE_MS_TENSOR_H_
|
||||
#define MINDSPORE_INCLUDE_MS_TENSOR_H_
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "mindspore/core/ir/dtype/type_id.h"
|
||||
|
||||
namespace mindspore {
|
||||
#define MS_API __attribute__((visibility("default")))
|
||||
namespace inference {
|
||||
class MS_API MSTensor {
|
||||
public:
|
||||
MSTensor() = default;
|
||||
// brief Create a MSTensor pointer.
|
||||
//
|
||||
// param data_type DataTypeId of tensor to be created.
|
||||
// param shape Shape of tensor to be created.
|
||||
// return MSTensor pointer.
|
||||
static MSTensor *CreateTensor(TypeId data_type, const std::vector<int> &shape);
|
||||
|
||||
~MSTensor() = default;
|
||||
|
||||
virtual TypeId data_type() const = 0;
|
||||
|
||||
virtual TypeId set_data_type(const TypeId data_type) = 0;
|
||||
|
||||
virtual std::vector<int> shape() const = 0;
|
||||
|
||||
virtual size_t set_shape(const std::vector<int> &shape) = 0;
|
||||
|
||||
virtual int DimensionSize(size_t index) const = 0;
|
||||
// brief Get number of element in MSTensor.
|
||||
//
|
||||
// return Number of element in MSTensor.
|
||||
virtual int ElementsNum() const = 0;
|
||||
|
||||
virtual std::size_t hash() const = 0;
|
||||
// brief Get byte size of data in MSTensor.
|
||||
//
|
||||
// return Byte size of data in MSTensor.
|
||||
virtual size_t Size() const = 0;
|
||||
// brief Get pointer of data in MSTensor.
|
||||
//
|
||||
// The data pointer can be used to both write or read data in MSTensor.
|
||||
//
|
||||
// return A pointer points to data in MSTensor.
|
||||
virtual void *MutableData() const = 0;
|
||||
};
|
||||
using MultiTensor = std::vector<std::shared_ptr<inference::MSTensor>>;
|
||||
} // namespace inference
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_MS_TENSOR_H_
|
|
@ -297,7 +297,7 @@ set(LOAD_ONNX_SRC
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/utils/load_onnx/anf_model_parser.cc
|
||||
)
|
||||
add_library(inference SHARED
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/backend/session/session.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/backend/session/infer_session.cc
|
||||
${LOAD_ONNX_SRC}
|
||||
)
|
||||
target_link_libraries(inference PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
|
||||
|
|
|
@ -88,8 +88,7 @@ GraphId AscendInferenceSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
|
|||
return graph_id;
|
||||
}
|
||||
|
||||
bool AscendInferenceSession::CheckModelInputs(uint32_t graph_id,
|
||||
const std::vector<std::shared_ptr<inference::MSTensor> > &inputs) {
|
||||
bool AscendInferenceSession::CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const {
|
||||
MS_LOG(INFO) << "Start check client inputs, graph id : " << graph_id;
|
||||
auto kernel_graph = GetGraph(graph_id);
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
|
@ -119,8 +118,7 @@ bool AscendInferenceSession::CheckModelInputs(uint32_t graph_id,
|
|||
return true;
|
||||
}
|
||||
|
||||
bool AscendInferenceSession::CompareInput(const std::shared_ptr<inference::MSTensor> &input,
|
||||
const ParameterPtr ¶meter) {
|
||||
bool AscendInferenceSession::CompareInput(const tensor::TensorPtr &input, const ParameterPtr ¶meter) const {
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
// compare dims
|
||||
|
@ -155,7 +153,7 @@ bool AscendInferenceSession::CompareInput(const std::shared_ptr<inference::MSTen
|
|||
return true;
|
||||
}
|
||||
|
||||
std::string AscendInferenceSession::PrintInputShape(std::vector<size_t> shape) {
|
||||
std::string AscendInferenceSession::PrintInputShape(std::vector<size_t> shape) const {
|
||||
string res = "[";
|
||||
for (auto dim : shape) {
|
||||
res += " " + std::to_string(dim);
|
||||
|
|
|
@ -39,9 +39,9 @@ class AscendInferenceSession : public AscendSession {
|
|||
void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
||||
const std::vector<tensor::TensorPtr> &inputs_const) const;
|
||||
GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) override;
|
||||
bool CheckModelInputs(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) override;
|
||||
bool CompareInput(const std::shared_ptr<inference::MSTensor> &input, const ParameterPtr ¶meter);
|
||||
std::string PrintInputShape(std::vector<size_t> shape);
|
||||
bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const override;
|
||||
bool CompareInput(const tensor::TensorPtr &input, const ParameterPtr ¶meter) const;
|
||||
std::string PrintInputShape(std::vector<size_t> shape) const;
|
||||
};
|
||||
MS_REG_SESSION(kDavinciInferenceDevice, AscendInferenceSession);
|
||||
} // namespace session
|
||||
|
|
|
@ -0,0 +1,362 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/session/infer_session.h"
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include "include/inference.h"
|
||||
#include "utils/load_onnx/anf_converter.h"
|
||||
#include "backend/session/session_basic.h"
|
||||
#include "backend/session/session_factory.h"
|
||||
#include "utils/base_ref_utils.h"
|
||||
#include "backend/kernel_compiler/oplib/oplib.h"
|
||||
|
||||
#ifdef ENABLE_D
|
||||
#include "utils/context/ms_context.h"
|
||||
#endif
|
||||
|
||||
using std::string;
|
||||
using std::vector;
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace mindspore::inference {
|
||||
|
||||
std::shared_ptr<InferSession> InferSession::CreateSession(const std::string &device, uint32_t device_id) {
|
||||
try {
|
||||
auto session = std::make_shared<MSInferSession>();
|
||||
bool ret = session->InitEnv(device, device_id);
|
||||
if (!ret) {
|
||||
return nullptr;
|
||||
}
|
||||
return session;
|
||||
} catch (std::bad_alloc &e) {
|
||||
MS_LOG(ERROR) << "Inference CreatSession failed, failed to alloc memory";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
MSInferSession::MSInferSession() = default;
|
||||
MSInferSession::~MSInferSession() = default;
|
||||
|
||||
std::shared_ptr<std::vector<char>> MSInferSession::ReadFile(const std::string &file) {
|
||||
if (file.empty()) {
|
||||
MS_LOG(ERROR) << "file is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
std::string realPath = file;
|
||||
std::ifstream ifs(realPath);
|
||||
if (!ifs.good()) {
|
||||
MS_LOG(ERROR) << "file: " << realPath << " is not exist";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!ifs.is_open()) {
|
||||
MS_LOG(ERROR) << "file: " << realPath << "open failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ifs.seekg(0, std::ios::end);
|
||||
size_t size = ifs.tellg();
|
||||
std::shared_ptr<std::vector<char>> buf(new (std::nothrow) std::vector<char>(size));
|
||||
if (buf == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc buf failed, file: " << realPath;
|
||||
ifs.close();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ifs.seekg(0, std::ios::beg);
|
||||
ifs.read(buf->data(), size);
|
||||
ifs.close();
|
||||
|
||||
return buf;
|
||||
}
|
||||
|
||||
bool MSInferSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) {
|
||||
auto graphBuf = ReadFile(file_name);
|
||||
if (graphBuf == nullptr) {
|
||||
MS_LOG(ERROR) << "Read model file failed, file name is " << file_name.c_str();
|
||||
return false;
|
||||
}
|
||||
auto graph = LoadModel(graphBuf->data(), graphBuf->size(), device_type_);
|
||||
if (graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
|
||||
return false;
|
||||
}
|
||||
bool ret = CompileGraph(graph, model_id);
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "Compile graph model failed, file name is " << file_name.c_str();
|
||||
return false;
|
||||
}
|
||||
MS_LOG(INFO) << "Load model from file " << file_name << " success";
|
||||
|
||||
#ifdef ENABLE_D
|
||||
// set d context
|
||||
rtError_t rt_ret = rtCtxGetCurrent(&context_);
|
||||
if (rt_ret != RT_ERROR_NONE || context_ == nullptr) {
|
||||
MS_LOG(ERROR) << "the ascend device context is null";
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MSInferSession::UnloadModel(uint32_t model_id) { return true; }
|
||||
|
||||
tensor::TensorPtr ServingTensor2MSTensor(const InferTensorBase &out_tensor) {
|
||||
std::vector<int> shape;
|
||||
for (auto dim : out_tensor.shape()) {
|
||||
shape.push_back(static_cast<int>(dim));
|
||||
}
|
||||
TypeId data_type;
|
||||
const std::map<inference::DataType, TypeId> type2id_map{
|
||||
{inference::kMSI_Unknown, TypeId::kNumberTypeBegin}, {inference::kMSI_Bool, TypeId::kNumberTypeBool},
|
||||
{inference::kMSI_Int8, TypeId::kNumberTypeInt8}, {inference::kMSI_Uint8, TypeId::kNumberTypeUInt8},
|
||||
{inference::kMSI_Int16, TypeId::kNumberTypeInt16}, {inference::kMSI_Uint16, TypeId::kNumberTypeUInt16},
|
||||
{inference::kMSI_Int32, TypeId::kNumberTypeInt32}, {inference::kMSI_Uint32, TypeId::kNumberTypeUInt32},
|
||||
{inference::kMSI_Int64, TypeId::kNumberTypeInt64}, {inference::kMSI_Uint64, TypeId::kNumberTypeUInt64},
|
||||
{inference::kMSI_Float16, TypeId::kNumberTypeFloat16}, {inference::kMSI_Float32, TypeId::kNumberTypeFloat32},
|
||||
{inference::kMSI_Float64, TypeId::kNumberTypeFloat64},
|
||||
};
|
||||
auto it = type2id_map.find(out_tensor.data_type());
|
||||
if (it == type2id_map.end()) {
|
||||
MSI_LOG_WARNING << "undefined MSI data type " << out_tensor.data_type();
|
||||
return nullptr;
|
||||
} else {
|
||||
data_type = it->second;
|
||||
}
|
||||
|
||||
auto ms_tensor = std::make_shared<tensor::Tensor>(data_type, shape);
|
||||
memcpy_s(ms_tensor->data_c(), ms_tensor->Size(), out_tensor.data(), out_tensor.data_size());
|
||||
return ms_tensor;
|
||||
}
|
||||
|
||||
void MSTensor2ServingTensor(tensor::TensorPtr ms_tensor, InferTensorBase &out_tensor) {
|
||||
vector<int64_t> shape;
|
||||
for (auto dim : ms_tensor->shape()) {
|
||||
shape.push_back(dim);
|
||||
}
|
||||
out_tensor.set_shape(shape);
|
||||
|
||||
const std::map<TypeId, inference::DataType> id2type_map{
|
||||
{TypeId::kNumberTypeBegin, inference::kMSI_Unknown}, {TypeId::kNumberTypeBool, inference::kMSI_Bool},
|
||||
{TypeId::kNumberTypeFloat64, inference::kMSI_Float64}, {TypeId::kNumberTypeInt8, inference::kMSI_Int8},
|
||||
{TypeId::kNumberTypeUInt8, inference::kMSI_Uint8}, {TypeId::kNumberTypeInt16, inference::kMSI_Int16},
|
||||
{TypeId::kNumberTypeUInt16, inference::kMSI_Uint16}, {TypeId::kNumberTypeInt32, inference::kMSI_Int32},
|
||||
{TypeId::kNumberTypeUInt32, inference::kMSI_Uint32}, {TypeId::kNumberTypeInt64, inference::kMSI_Int64},
|
||||
{TypeId::kNumberTypeUInt64, inference::kMSI_Uint64}, {TypeId::kNumberTypeFloat16, inference::kMSI_Float16},
|
||||
{TypeId::kNumberTypeFloat32, inference::kMSI_Float32},
|
||||
};
|
||||
auto it = id2type_map.find(ms_tensor->data_type());
|
||||
if (it == id2type_map.end()) {
|
||||
MSI_LOG_WARNING << "undefined MS data type " << ms_tensor->data_type();
|
||||
out_tensor.set_data_type(inference::kMSI_Unknown);
|
||||
} else {
|
||||
out_tensor.set_data_type(it->second);
|
||||
}
|
||||
out_tensor.set_data(ms_tensor->data_c(), ms_tensor->Size());
|
||||
}
|
||||
|
||||
bool MSInferSession::ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) {
|
||||
#ifdef ENABLE_D
|
||||
if (context_ == nullptr) {
|
||||
MS_LOG(ERROR) << "rtCtx is nullptr";
|
||||
return false;
|
||||
}
|
||||
rtError_t rt_ret = rtCtxSetCurrent(context_);
|
||||
if (rt_ret != RT_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "set Ascend rtCtx failed";
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
vector<tensor::TensorPtr> inputs;
|
||||
for (size_t i = 0; i < request.size(); i++) {
|
||||
if (request[i] == nullptr) {
|
||||
MS_LOG(ERROR) << "Execute Model " << model_id << " Failed, input tensor is null, index " << i;
|
||||
return false;
|
||||
}
|
||||
auto input = ServingTensor2MSTensor(*request[i]);
|
||||
if (input == nullptr) {
|
||||
MS_LOG(ERROR) << "Tensor convert failed";
|
||||
return false;
|
||||
}
|
||||
inputs.push_back(input);
|
||||
}
|
||||
if (!CheckModelInputs(model_id, inputs)) {
|
||||
MS_LOG(ERROR) << "Check Model " << model_id << " Inputs Failed";
|
||||
return false;
|
||||
}
|
||||
vector<tensor::TensorPtr> outputs = RunGraph(model_id, inputs);
|
||||
if (outputs.empty()) {
|
||||
MS_LOG(ERROR) << "Execute Model " << model_id << " Failed";
|
||||
return false;
|
||||
}
|
||||
reply.clear();
|
||||
for (const auto &tensor : outputs) {
|
||||
auto out_tensor = reply.add();
|
||||
if (out_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Execute Model " << model_id << " Failed, add output tensor failed";
|
||||
return false;
|
||||
}
|
||||
MSTensor2ServingTensor(tensor, *out_tensor);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MSInferSession::FinalizeEnv() {
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
if (ms_context == nullptr) {
|
||||
MS_LOG(ERROR) << "Get Context failed!";
|
||||
return false;
|
||||
}
|
||||
if (!ms_context->CloseTsd()) {
|
||||
MS_LOG(ERROR) << "Inference CloseTsd failed!";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<FuncGraph> MSInferSession::LoadModel(const char *model_buf, size_t size, const std::string &device) {
|
||||
try {
|
||||
auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size);
|
||||
return anf_graph;
|
||||
} catch (std::exception &e) {
|
||||
MS_LOG(ERROR) << "Inference LoadModel failed";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void MSInferSession::RegAllOp() {
|
||||
static std::mutex init_mutex;
|
||||
static bool Initialized = false;
|
||||
|
||||
std::lock_guard<std::mutex> lock(init_mutex);
|
||||
if (Initialized) {
|
||||
return;
|
||||
}
|
||||
Initialized = true;
|
||||
MsContext::GetInstance()->set_execution_mode(kGraphMode);
|
||||
Py_Initialize();
|
||||
auto c_expression = PyImport_ImportModule("mindspore._c_expression");
|
||||
if (c_expression == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed to import mindspore._c_expression module.";
|
||||
return;
|
||||
}
|
||||
PyObject *c_expression_dict = PyModule_GetDict(c_expression);
|
||||
|
||||
PyObject *op_info_loader_class = PyDict_GetItemString(c_expression_dict, "OpInfoLoaderPy");
|
||||
if (op_info_loader_class == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed to get op_info_loader_class from mindspore._c_expression.";
|
||||
return;
|
||||
}
|
||||
PyObject *op_info_loader = PyInstanceMethod_New(op_info_loader_class);
|
||||
if (op_info_loader == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed to create op_info_loader instance.";
|
||||
return;
|
||||
}
|
||||
PyObject *op_info_loader_ins = PyObject_CallObject(op_info_loader, nullptr);
|
||||
if (op_info_loader_ins == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed to call op_info_loader instance.";
|
||||
return;
|
||||
}
|
||||
auto all_ops_info_vector_addr_ul = PyObject_CallMethod(op_info_loader_ins, "get_all_ops_info", nullptr);
|
||||
if (all_ops_info_vector_addr_ul == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed to call get_all_ops_addr.";
|
||||
return;
|
||||
}
|
||||
auto all_ops_info_vector_addr = PyLong_AsVoidPtr(all_ops_info_vector_addr_ul);
|
||||
auto all_ops_info = static_cast<std::vector<kernel::OpInfo *> *>(all_ops_info_vector_addr);
|
||||
for (auto op_info : *all_ops_info) {
|
||||
kernel::OpLib::RegOpInfo(std::shared_ptr<kernel::OpInfo>(op_info));
|
||||
}
|
||||
all_ops_info->clear();
|
||||
delete all_ops_info;
|
||||
Py_DECREF(op_info_loader);
|
||||
Py_DECREF(op_info_loader_class);
|
||||
Py_DECREF(c_expression_dict);
|
||||
Py_DECREF(c_expression);
|
||||
return;
|
||||
}
|
||||
|
||||
bool MSInferSession::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id) {
|
||||
MS_ASSERT(session_impl_ != nullptr);
|
||||
try {
|
||||
auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
|
||||
py::gil_scoped_release gil_release;
|
||||
model_id = graph_id;
|
||||
return true;
|
||||
} catch (std::exception &e) {
|
||||
MS_LOG(ERROR) << "Inference CompileGraph failed";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<tensor::TensorPtr> MSInferSession::RunGraph(uint32_t graph_id,
|
||||
const std::vector<tensor::TensorPtr> &inputs) {
|
||||
try {
|
||||
VectorRef outputs;
|
||||
session_impl_->RunGraph(graph_id, inputs, &outputs);
|
||||
|
||||
return TransformVectorRefToMultiTensor(outputs);
|
||||
} catch (std::exception &e) {
|
||||
MS_LOG(ERROR) << "Inference Rungraph failed";
|
||||
return std::vector<tensor::TensorPtr>();
|
||||
}
|
||||
}
|
||||
|
||||
string MSInferSession::AjustTargetName(const std::string &device) {
|
||||
if (device == kAscendDevice) {
|
||||
return std::string(kAscendDevice) + "Inference";
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Only support device Ascend right now";
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
bool MSInferSession::InitEnv(const std::string &device, uint32_t device_id) {
|
||||
RegAllOp();
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
ms_context->set_execution_mode(kGraphMode);
|
||||
ms_context->set_device_id(device_id);
|
||||
auto ajust_device = AjustTargetName(device);
|
||||
if (ajust_device == "") {
|
||||
return false;
|
||||
}
|
||||
ms_context->set_device_target(device);
|
||||
session_impl_ = session::SessionFactory::Get().Create(ajust_device);
|
||||
if (session_impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available.";
|
||||
return false;
|
||||
}
|
||||
session_impl_->Init(device_id);
|
||||
if (ms_context == nullptr) {
|
||||
MS_LOG(ERROR) << "Get Context failed!";
|
||||
return false;
|
||||
}
|
||||
if (!ms_context->OpenTsd()) {
|
||||
MS_LOG(ERROR) << "Session init OpenTsd failed!";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MSInferSession::CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const {
|
||||
MS_ASSERT(session_impl_ != nullptr);
|
||||
return session_impl_->CheckModelInputs(graph_id, inputs);
|
||||
}
|
||||
|
||||
} // namespace mindspore::inference
|
|
@ -0,0 +1,66 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_SESSION_SESSION_H
|
||||
#define MINDSPORE_CCSRC_SESSION_SESSION_H
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
|
||||
#include "backend/session/session_basic.h"
|
||||
#include "ir/anf.h"
|
||||
#include "include/inference.h"
|
||||
|
||||
#ifdef ENABLE_D
|
||||
#include "runtime/context.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace inference {
|
||||
class MSInferSession : public InferSession {
|
||||
public:
|
||||
MSInferSession();
|
||||
~MSInferSession();
|
||||
|
||||
bool InitEnv(const std::string &device_type, uint32_t device_id) override;
|
||||
bool FinalizeEnv() override;
|
||||
bool LoadModelFromFile(const std::string &file_name, uint32_t &model_id) override;
|
||||
bool UnloadModel(uint32_t model_id) override;
|
||||
bool ExecuteModel(uint32_t model_id, const RequestBase &inputs, ReplyBase &outputs) override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<session::SessionBasic> session_impl_ = nullptr;
|
||||
std::vector<uint32_t> graph_id_;
|
||||
std::string device_type_;
|
||||
int32_t device_id_;
|
||||
#ifdef ENABLE_D
|
||||
rtContext_t context_ = nullptr;
|
||||
#endif
|
||||
|
||||
std::shared_ptr<FuncGraph> LoadModel(const char *model_buf, size_t size, const std::string &device);
|
||||
std::shared_ptr<std::vector<char>> ReadFile(const std::string &file);
|
||||
static void RegAllOp();
|
||||
string AjustTargetName(const std::string &device);
|
||||
bool CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id);
|
||||
bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const;
|
||||
std::vector<tensor::TensorPtr> RunGraph(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs);
|
||||
};
|
||||
} // namespace inference
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H
|
|
@ -1,214 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include "include/inference.h"
|
||||
#include "backend/session/session.h"
|
||||
#include "utils/load_onnx/anf_converter.h"
|
||||
#include "backend/session/session_basic.h"
|
||||
#include "backend/session/session_factory.h"
|
||||
#include "utils/base_ref_utils.h"
|
||||
#include "backend/kernel_compiler/oplib/oplib.h"
|
||||
#ifdef ENABLE_D
|
||||
#include "utils/context/ms_context.h"
|
||||
#include "backend/session/ascend_session.h"
|
||||
#else
|
||||
#include "backend/session/cpu_session.h"
|
||||
#endif
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace mindspore::inference {
|
||||
std::shared_ptr<FuncGraph> LoadModel(const char *model_buf, size_t size, const std::string &device) {
|
||||
try {
|
||||
inference::Session::RegAllOp();
|
||||
auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size);
|
||||
return anf_graph;
|
||||
} catch (std::exception &e) {
|
||||
MS_LOG(ERROR) << "Inference LoadModel failed";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void ExitInference() {
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
if (ms_context == nullptr) {
|
||||
MS_LOG(ERROR) << "Get Context failed!";
|
||||
return;
|
||||
}
|
||||
if (!ms_context->CloseTsd()) {
|
||||
MS_LOG(ERROR) << "Inference CloseTsd failed!";
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<MSSession> MSSession::CreateSession(const std::string &device, uint32_t device_id) {
|
||||
try {
|
||||
auto session = std::make_shared<inference::Session>();
|
||||
auto ret = session->Init(device, device_id);
|
||||
if (ret != 0) {
|
||||
return nullptr;
|
||||
}
|
||||
return session;
|
||||
} catch (std::exception &e) {
|
||||
MS_LOG(ERROR) << "Inference CreatSession failed";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void Session::RegAllOp() {
|
||||
static std::mutex init_mutex;
|
||||
static bool Initialized = false;
|
||||
|
||||
std::lock_guard<std::mutex> lock(init_mutex);
|
||||
if (Initialized) {
|
||||
return;
|
||||
}
|
||||
Initialized = true;
|
||||
MsContext::GetInstance()->set_execution_mode(kGraphMode);
|
||||
Py_Initialize();
|
||||
auto c_expression = PyImport_ImportModule("mindspore._c_expression");
|
||||
if (c_expression == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed to import mindspore._c_expression module.";
|
||||
return;
|
||||
}
|
||||
PyObject *c_expression_dict = PyModule_GetDict(c_expression);
|
||||
|
||||
PyObject *op_info_loader_class = PyDict_GetItemString(c_expression_dict, "OpInfoLoaderPy");
|
||||
if (op_info_loader_class == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed to get op_info_loader_class from mindspore._c_expression.";
|
||||
return;
|
||||
}
|
||||
PyObject *op_info_loader = PyInstanceMethod_New(op_info_loader_class);
|
||||
if (op_info_loader == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed to create op_info_loader instance.";
|
||||
return;
|
||||
}
|
||||
PyObject *op_info_loader_ins = PyObject_CallObject(op_info_loader, nullptr);
|
||||
if (op_info_loader_ins == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed to call op_info_loader instance.";
|
||||
return;
|
||||
}
|
||||
auto all_ops_info_vector_addr_ul = PyObject_CallMethod(op_info_loader_ins, "get_all_ops_info", nullptr);
|
||||
if (all_ops_info_vector_addr_ul == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed to call get_all_ops_addr.";
|
||||
return;
|
||||
}
|
||||
auto all_ops_info_vector_addr = PyLong_AsVoidPtr(all_ops_info_vector_addr_ul);
|
||||
auto all_ops_info = static_cast<std::vector<kernel::OpInfo *> *>(all_ops_info_vector_addr);
|
||||
for (auto op_info : *all_ops_info) {
|
||||
kernel::OpLib::RegOpInfo(std::shared_ptr<kernel::OpInfo>(op_info));
|
||||
}
|
||||
all_ops_info->clear();
|
||||
delete all_ops_info;
|
||||
Py_DECREF(op_info_loader);
|
||||
Py_DECREF(op_info_loader_class);
|
||||
Py_DECREF(c_expression_dict);
|
||||
Py_DECREF(c_expression);
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t Session::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr) {
|
||||
MS_ASSERT(session_impl_ != nullptr);
|
||||
try {
|
||||
auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
|
||||
py::gil_scoped_release gil_release;
|
||||
return graph_id;
|
||||
} catch (std::exception &e) {
|
||||
MS_LOG(ERROR) << "Inference CompileGraph failed";
|
||||
return static_cast<uint32_t>(-1);
|
||||
}
|
||||
}
|
||||
|
||||
MultiTensor Session::RunGraph(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) {
|
||||
try {
|
||||
std::vector<tensor::TensorPtr> inTensors;
|
||||
inTensors.resize(inputs.size());
|
||||
bool has_error = false;
|
||||
std::transform(inputs.begin(), inputs.end(), inTensors.begin(),
|
||||
[&has_error](const std::shared_ptr<inference::MSTensor> &tensor_ptr) -> tensor::TensorPtr {
|
||||
if (tensor_ptr == nullptr) {
|
||||
MS_LOG(WARNING) << "input MSTensor is nullptr, return nullptr";
|
||||
has_error = true;
|
||||
return nullptr;
|
||||
}
|
||||
auto tensor = static_cast<inference::Tensor *>(tensor_ptr.get());
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Can not cast input MSTensor to tensor";
|
||||
has_error = true;
|
||||
return nullptr;
|
||||
}
|
||||
return tensor->tensor();
|
||||
});
|
||||
if (has_error) {
|
||||
MS_LOG(ERROR) << "Init Tensor failed, returning empty result";
|
||||
std::vector<std::shared_ptr<inference::MSTensor>> multiTensor;
|
||||
return multiTensor;
|
||||
}
|
||||
VectorRef outputs;
|
||||
session_impl_->RunGraph(graph_id, inTensors, &outputs);
|
||||
|
||||
return TransformVectorRefToMultiTensor(outputs);
|
||||
} catch (std::exception &e) {
|
||||
MS_LOG(ERROR) << "Inference Rungraph failed";
|
||||
return MultiTensor();
|
||||
}
|
||||
}
|
||||
namespace {
|
||||
string AjustTargetName(const std::string &device) {
|
||||
if (device == kAscendDevice) {
|
||||
return std::string(kAscendDevice) + "Inference";
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Only support device Ascend right now";
|
||||
return "";
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
int Session::Init(const std::string &device, uint32_t device_id) {
|
||||
RegAllOp();
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
ms_context->set_execution_mode(kGraphMode);
|
||||
ms_context->set_device_id(device_id);
|
||||
auto ajust_device = AjustTargetName(device);
|
||||
if (ajust_device == "") {
|
||||
return -1;
|
||||
}
|
||||
ms_context->set_device_target(device);
|
||||
session_impl_ = session::SessionFactory::Get().Create(ajust_device);
|
||||
if (session_impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available.";
|
||||
return -1;
|
||||
}
|
||||
session_impl_->Init(device_id);
|
||||
if (ms_context == nullptr) {
|
||||
MS_LOG(ERROR) << "Get Context failed!";
|
||||
return -1;
|
||||
}
|
||||
if (!ms_context->OpenTsd()) {
|
||||
MS_LOG(ERROR) << "Session init OpenTsd failed!";
|
||||
return -1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool Session::CheckModelInputs(uint32_t graph_id,
|
||||
const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) const {
|
||||
MS_ASSERT(session_impl_ != nullptr);
|
||||
return session_impl_->CheckModelInputs(graph_id, inputs);
|
||||
}
|
||||
|
||||
Session::Session() = default;
|
||||
} // namespace mindspore::inference
|
|
@ -276,7 +276,7 @@ bool ExistSummaryNode(const KernelGraph *graph) {
|
|||
|
||||
GraphId SessionBasic::graph_sum_ = 0;
|
||||
|
||||
KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) {
|
||||
KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) const {
|
||||
auto it = graphs_.find(graph_id);
|
||||
if (it == graphs_.end()) {
|
||||
MS_LOG(WARNING) << "Can't find graph " << graph_id;
|
||||
|
|
|
@ -106,9 +106,7 @@ class SessionBasic {
|
|||
virtual void GetSummaryNodes(KernelGraph *graph);
|
||||
void AssignParamKey(const KernelGraphPtr &kernel_graph);
|
||||
void InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &inputs_const);
|
||||
virtual bool CheckModelInputs(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) {
|
||||
return true;
|
||||
}
|
||||
virtual bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const { return true; }
|
||||
|
||||
#ifdef ENABLE_DEBUGGER
|
||||
// set debugger
|
||||
|
@ -120,7 +118,7 @@ class SessionBasic {
|
|||
|
||||
protected:
|
||||
// Get graph by graph id ,if not exist return null ptr
|
||||
KernelGraphPtr GetGraph(GraphId graph_id);
|
||||
KernelGraphPtr GetGraph(GraphId graph_id) const;
|
||||
virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
||||
const std::vector<tensor::TensorPtr> &inputs_const) const;
|
||||
void UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
|
||||
|
|
|
@ -17,17 +17,17 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
#include "utils/base_ref_utils.h"
|
||||
#include "include/ms_tensor.h"
|
||||
#include "include/infer_tensor.h"
|
||||
#include "ir/tensor.h"
|
||||
|
||||
namespace mindspore {
|
||||
void IterateFindTensor(std::vector<std::shared_ptr<inference::MSTensor>> *msTensors, const VectorRef &ref_list) {
|
||||
|
||||
void IterateFindTensor(std::vector<tensor::TensorPtr> *msTensors, const VectorRef &ref_list) {
|
||||
for (size_t i = 0; i < ref_list.size(); ++i) {
|
||||
if (utils::isa<tensor::TensorPtr>(ref_list[i])) {
|
||||
auto tensor_ptr = utils::cast<std::shared_ptr<tensor::Tensor>>(ref_list[i]);
|
||||
MS_EXCEPTION_IF_NULL(tensor_ptr);
|
||||
auto tensor = new inference::Tensor(tensor_ptr);
|
||||
msTensors->emplace_back(std::shared_ptr<inference::MSTensor>(tensor));
|
||||
msTensors->emplace_back(tensor_ptr);
|
||||
} else if (utils::isa<VectorRef>(ref_list[i])) {
|
||||
auto ref_iter = utils::cast<VectorRef>(ref_list[i]);
|
||||
IterateFindTensor(msTensors, ref_iter);
|
||||
|
@ -37,19 +37,19 @@ void IterateFindTensor(std::vector<std::shared_ptr<inference::MSTensor>> *msTens
|
|||
}
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<inference::MSTensor>> TransformVectorRefToMultiTensor(const VectorRef &base_ref) {
|
||||
std::vector<std::shared_ptr<inference::MSTensor>> msTensors;
|
||||
std::vector<tensor::TensorPtr> TransformVectorRefToMultiTensor(const VectorRef &base_ref) {
|
||||
std::vector<tensor::TensorPtr> msTensors;
|
||||
if (utils::isa<VectorRef>(base_ref)) {
|
||||
auto ref_list = utils::cast<VectorRef>(base_ref);
|
||||
IterateFindTensor(&msTensors, ref_list);
|
||||
} else if (utils::isa<tensor::Tensor>(base_ref)) {
|
||||
auto tensor_ptr = utils::cast<std::shared_ptr<tensor::Tensor>>(base_ref);
|
||||
MS_EXCEPTION_IF_NULL(tensor_ptr);
|
||||
auto tensor = new inference::Tensor(tensor_ptr);
|
||||
msTensors.emplace_back(std::shared_ptr<inference::MSTensor>(tensor));
|
||||
msTensors.emplace_back(tensor_ptr);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
|
||||
}
|
||||
return msTensors;
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -17,11 +17,12 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
#include "utils/base_ref.h"
|
||||
#include "include/ms_tensor.h"
|
||||
#include "include/infer_tensor.h"
|
||||
#include "ir/tensor.h"
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H
|
||||
#define MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H
|
||||
namespace mindspore {
|
||||
std::vector<std::shared_ptr<inference::MSTensor>> TransformVectorRefToMultiTensor(const VectorRef &base_ref);
|
||||
std::vector<tensor::TensorPtr> TransformVectorRefToMultiTensor(const VectorRef &base_ref);
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H
|
||||
|
|
|
@ -85,68 +85,4 @@ bool Tensor::operator==(const Value &other) const {
|
|||
}
|
||||
}
|
||||
} // namespace tensor
|
||||
|
||||
namespace inference {
|
||||
MSTensor *MSTensor::CreateTensor(TypeId data_type, const std::vector<int> &shape) {
|
||||
return new Tensor(data_type, shape);
|
||||
}
|
||||
|
||||
Tensor::Tensor() { this->tensor_impl_ = std::make_shared<tensor::Tensor>(); }
|
||||
|
||||
Tensor::Tensor(TypeId data_type, const std::vector<int> &shape) {
|
||||
this->tensor_impl_ = std::make_shared<tensor::Tensor>(data_type, shape);
|
||||
}
|
||||
|
||||
Tensor::Tensor(std::shared_ptr<tensor::Tensor> tensor_ptr) { this->tensor_impl_ = std::move(tensor_ptr); }
|
||||
|
||||
TypeId Tensor::data_type() const {
|
||||
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||
return this->tensor_impl_->data_type();
|
||||
}
|
||||
|
||||
TypeId Tensor::set_data_type(TypeId data_type) {
|
||||
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||
return this->tensor_impl_->set_data_type(data_type);
|
||||
}
|
||||
|
||||
std::vector<int> Tensor::shape() const {
|
||||
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||
return this->tensor_impl_->shape();
|
||||
}
|
||||
|
||||
size_t Tensor::set_shape(const std::vector<int> &shape) {
|
||||
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||
return this->tensor_impl_->set_shape(shape);
|
||||
}
|
||||
|
||||
int Tensor::DimensionSize(size_t index) const {
|
||||
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||
return this->tensor_impl_->DimensionSize(index);
|
||||
}
|
||||
|
||||
int Tensor::ElementsNum() const {
|
||||
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||
return this->tensor_impl_->ElementsNum();
|
||||
}
|
||||
|
||||
std::size_t Tensor::hash() const {
|
||||
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||
return this->tensor_impl_->hash();
|
||||
}
|
||||
|
||||
std::shared_ptr<tensor::Tensor> Tensor::tensor() const {
|
||||
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||
return this->tensor_impl_;
|
||||
}
|
||||
|
||||
size_t Tensor::Size() const {
|
||||
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||
return this->tensor_impl_->Size();
|
||||
}
|
||||
|
||||
void *Tensor::MutableData() const {
|
||||
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||
return this->tensor_impl_->data();
|
||||
}
|
||||
} // namespace inference
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -56,42 +56,6 @@ class Tensor : public MetaTensor {
|
|||
|
||||
using TensorPtr = std::shared_ptr<Tensor>;
|
||||
} // namespace tensor
|
||||
|
||||
namespace inference {
|
||||
class Tensor : public MSTensor {
|
||||
public:
|
||||
Tensor();
|
||||
|
||||
Tensor(TypeId data_type, const std::vector<int> &shape);
|
||||
|
||||
explicit Tensor(std::shared_ptr<tensor::Tensor> tensor_ptr);
|
||||
|
||||
~Tensor() = default;
|
||||
|
||||
TypeId data_type() const override;
|
||||
|
||||
TypeId set_data_type(const TypeId data_type) override;
|
||||
|
||||
std::vector<int> shape() const override;
|
||||
|
||||
size_t set_shape(const std::vector<int> &shape) override;
|
||||
|
||||
int DimensionSize(size_t index) const override;
|
||||
|
||||
int ElementsNum() const override;
|
||||
|
||||
std::size_t hash() const override;
|
||||
|
||||
std::shared_ptr<tensor::Tensor> tensor() const;
|
||||
|
||||
size_t Size() const override;
|
||||
|
||||
void *MutableData() const override;
|
||||
|
||||
protected:
|
||||
std::shared_ptr<tensor::Tensor> tensor_impl_;
|
||||
};
|
||||
} // namespace inference
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_IR_LITE_TENSOR_H_
|
||||
|
|
|
@ -454,67 +454,4 @@ TypeId Tensor::set_data_type(const TypeId data_type) {
|
|||
return data_type;
|
||||
}
|
||||
} // namespace tensor
|
||||
|
||||
namespace inference {
|
||||
MSTensor *MSTensor::CreateTensor(TypeId data_type, const std::vector<int> &shape) {
|
||||
return new Tensor(data_type, shape);
|
||||
}
|
||||
|
||||
Tensor::Tensor(TypeId data_type, const std::vector<int> &shape) {
|
||||
this->tensor_impl_ = std::make_shared<tensor::Tensor>(data_type, shape);
|
||||
}
|
||||
|
||||
Tensor::Tensor(std::shared_ptr<tensor::Tensor> tensor_ptr) { this->tensor_impl_ = std::move(tensor_ptr); }
|
||||
|
||||
TypeId Tensor::data_type() const {
|
||||
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||
return this->tensor_impl_->data_type();
|
||||
}
|
||||
|
||||
TypeId Tensor::set_data_type(TypeId data_type) {
|
||||
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||
return this->tensor_impl_->set_data_type(data_type);
|
||||
}
|
||||
|
||||
std::vector<int> Tensor::shape() const {
|
||||
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||
return this->tensor_impl_->shape();
|
||||
}
|
||||
|
||||
size_t Tensor::set_shape(const std::vector<int> &shape) {
|
||||
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||
return this->tensor_impl_->set_shape(shape);
|
||||
}
|
||||
|
||||
int Tensor::DimensionSize(size_t index) const {
|
||||
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||
return this->tensor_impl_->DimensionSize(index);
|
||||
}
|
||||
|
||||
int Tensor::ElementsNum() const {
|
||||
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||
return this->tensor_impl_->ElementsNum();
|
||||
}
|
||||
|
||||
std::size_t Tensor::hash() const {
|
||||
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||
return this->tensor_impl_->hash();
|
||||
}
|
||||
|
||||
std::shared_ptr<tensor::Tensor> Tensor::tensor() const {
|
||||
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||
return this->tensor_impl_;
|
||||
}
|
||||
|
||||
size_t Tensor::Size() const {
|
||||
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||
return this->tensor_impl_->data().nbytes();
|
||||
}
|
||||
|
||||
void *Tensor::MutableData() const {
|
||||
MS_ASSERT(this->tensor_impl_ != nullptr);
|
||||
return this->tensor_impl_->data_c();
|
||||
}
|
||||
|
||||
} // namespace inference
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -25,7 +25,6 @@
|
|||
#include "Eigen/Core"
|
||||
#include "ir/device_sync.h"
|
||||
#include "ir/meta_tensor.h"
|
||||
#include "include/ms_tensor.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
using float16 = Eigen::half;
|
||||
|
@ -237,40 +236,6 @@ class Tensor : public MetaTensor {
|
|||
using TensorPtr = std::shared_ptr<Tensor>;
|
||||
using TensorPtrList = std::vector<std::shared_ptr<Tensor>>;
|
||||
} // namespace tensor
|
||||
|
||||
namespace inference {
|
||||
class Tensor : public MSTensor {
|
||||
public:
|
||||
Tensor(TypeId data_type, const std::vector<int> &shape);
|
||||
|
||||
explicit Tensor(std::shared_ptr<tensor::Tensor> tensor_ptr);
|
||||
|
||||
~Tensor() = default;
|
||||
|
||||
TypeId data_type() const override;
|
||||
|
||||
TypeId set_data_type(const TypeId data_type) override;
|
||||
|
||||
std::vector<int> shape() const override;
|
||||
|
||||
size_t set_shape(const std::vector<int> &shape) override;
|
||||
|
||||
int DimensionSize(size_t index) const override;
|
||||
|
||||
int ElementsNum() const override;
|
||||
|
||||
std::size_t hash() const override;
|
||||
|
||||
std::shared_ptr<tensor::Tensor> tensor() const;
|
||||
|
||||
size_t Size() const override;
|
||||
|
||||
void *MutableData() const override;
|
||||
|
||||
protected:
|
||||
std::shared_ptr<tensor::Tensor> tensor_impl_;
|
||||
};
|
||||
} // namespace inference
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_IR_TENSOR_H_
|
||||
|
|
|
@ -13,19 +13,19 @@ add_library(protobuf::libprotobuf ALIAS protobuf::protobuf)
|
|||
add_executable(protobuf::libprotoc ALIAS protobuf::protoc)
|
||||
|
||||
set(_PROTOBUF_LIBPROTOBUF protobuf::libprotobuf)
|
||||
if(CMAKE_CROSSCOMPILING)
|
||||
if (CMAKE_CROSSCOMPILING)
|
||||
find_program(_PROTOBUF_PROTOC protoc)
|
||||
else()
|
||||
else ()
|
||||
set(_PROTOBUF_PROTOC $<TARGET_FILE:protobuf::protoc>)
|
||||
endif()
|
||||
endif ()
|
||||
|
||||
# Find gRPC installation
|
||||
# Looks for gRPCConfig.cmake file installed by gRPC's cmake installation.
|
||||
if (EXISTS ${grpc_ROOT}/lib64)
|
||||
set(gRPC_DIR "${grpc_ROOT}/lib64/cmake/grpc")
|
||||
else()
|
||||
else ()
|
||||
set(gRPC_DIR "${grpc_ROOT}/lib/cmake/grpc")
|
||||
endif()
|
||||
endif ()
|
||||
message("serving using grpc_DIR : " ${gPRC_DIR})
|
||||
|
||||
find_package(gRPC CONFIG REQUIRED)
|
||||
|
@ -34,11 +34,11 @@ message(STATUS "Using gRPC ${gRPC_VERSION}")
|
|||
set(_GRPC_GRPCPP gRPC::grpc++)
|
||||
set(_REFLECTION gRPC::grpc++_reflection)
|
||||
|
||||
if(CMAKE_CROSSCOMPILING)
|
||||
if (CMAKE_CROSSCOMPILING)
|
||||
find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin)
|
||||
else()
|
||||
else ()
|
||||
set(_GRPC_CPP_PLUGIN_EXECUTABLE $<TARGET_FILE:gRPC::grpc_cpp_plugin>)
|
||||
endif()
|
||||
endif ()
|
||||
|
||||
# Proto file
|
||||
get_filename_component(hw_proto "ms_service.proto" ABSOLUTE)
|
||||
|
@ -67,11 +67,36 @@ file(GLOB_RECURSE CORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
|
||||
list(APPEND SERVING_SRC "main.cc" ${hw_proto_srcs} ${hw_grpc_srcs} ${CORE_SRC_LIST})
|
||||
|
||||
option(ENABLE_ACL "enable acl" OFF)
|
||||
|
||||
if (ENABLE_ACL)
|
||||
if (DEFINED ENV{ASCEND_CUSTOM_PATH})
|
||||
set(ASCEND_PATH $ENV{ASCEND_CUSTOM_PATH})
|
||||
else ()
|
||||
set(ASCEND_PATH /usr/local/Ascend)
|
||||
endif ()
|
||||
set(ACL_LIB_DIR ${ASCEND_PATH}/acllib/)
|
||||
MESSAGE("acl lib dir " ${ACL_LIB_DIR})
|
||||
|
||||
include_directories(${ACL_LIB_DIR}/include/)
|
||||
file(GLOB_RECURSE ACL_SESSION_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "acl/*.cc")
|
||||
list(APPEND SERVING_SRC ${ACL_SESSION_SRC_LIST})
|
||||
endif ()
|
||||
|
||||
include_directories(${CMAKE_BINARY_DIR})
|
||||
add_executable(ms_serving ${SERVING_SRC})
|
||||
target_link_libraries(ms_serving inference mindspore_gvar)
|
||||
|
||||
target_link_libraries(ms_serving ${_REFLECTION} ${_GRPC_GRPCPP} ${_PROTOBUF_LIBPROTOBUF} pthread)
|
||||
if (ENABLE_D)
|
||||
add_compile_definitions(ENABLE_D)
|
||||
target_link_libraries(ms_serving ${RUNTIME_LIB})
|
||||
endif()
|
||||
endif ()
|
||||
|
||||
if (ENABLE_ACL)
|
||||
add_compile_definitions(ENABLE_ACL)
|
||||
set(ALC_LIB_SO ${ACL_LIB_DIR}/lib64/libruntime.so ${ACL_LIB_DIR}/lib64/libascendcl.so
|
||||
${ACL_LIB_DIR}/lib64/libacl_retr.so ${ACL_LIB_DIR}/lib64/libacl_cblas.so)
|
||||
target_link_libraries(ms_serving ${ALC_LIB_SO})
|
||||
else ()
|
||||
target_link_libraries(ms_serving inference mindspore_gvar)
|
||||
endif ()
|
||||
|
|
|
@ -0,0 +1,136 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include "serving/acl/acl_session.h"
|
||||
#include "include/infer_log.h"
|
||||
|
||||
namespace mindspore::inference {
|
||||
|
||||
std::shared_ptr<InferSession> InferSession::CreateSession(const std::string &device, uint32_t device_id) {
|
||||
try {
|
||||
auto session = std::make_shared<AclSession>();
|
||||
auto ret = session->InitEnv(device, device_id);
|
||||
if (!ret) {
|
||||
return nullptr;
|
||||
}
|
||||
return session;
|
||||
} catch (std::exception &e) {
|
||||
MSI_LOG_ERROR << "Inference CreatSession failed";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
bool AclSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) {
|
||||
return model_process_.LoadModelFromFile(file_name, model_id);
|
||||
}
|
||||
|
||||
bool AclSession::UnloadModel(uint32_t model_id) {
|
||||
model_process_.UnLoad();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AclSession::ExecuteModel(uint32_t model_id, const RequestBase &request,
|
||||
ReplyBase &reply) { // set d context
|
||||
aclError rt_ret = aclrtSetCurrentContext(context_);
|
||||
if (rt_ret != ACL_ERROR_NONE) {
|
||||
MSI_LOG_ERROR << "set the ascend device context failed";
|
||||
return false;
|
||||
}
|
||||
return model_process_.Execute(request, reply);
|
||||
}
|
||||
|
||||
bool AclSession::InitEnv(const std::string &device_type, uint32_t device_id) {
|
||||
device_type_ = device_type;
|
||||
device_id_ = device_id;
|
||||
auto ret = aclInit(nullptr);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MSI_LOG_ERROR << "Execute aclInit Failed";
|
||||
return false;
|
||||
}
|
||||
MSI_LOG_INFO << "acl init success";
|
||||
|
||||
ret = aclrtSetDevice(device_id_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MSI_LOG_ERROR << "acl open device " << device_id_ << " failed";
|
||||
return false;
|
||||
}
|
||||
MSI_LOG_INFO << "open device " << device_id_ << " success";
|
||||
|
||||
ret = aclrtCreateContext(&context_, device_id_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MSI_LOG_ERROR << "acl create context failed";
|
||||
return false;
|
||||
}
|
||||
MSI_LOG_INFO << "create context success";
|
||||
|
||||
ret = aclrtCreateStream(&stream_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MSI_LOG_ERROR << "acl create stream failed";
|
||||
return false;
|
||||
}
|
||||
MSI_LOG_INFO << "create stream success";
|
||||
|
||||
aclrtRunMode run_mode;
|
||||
ret = aclrtGetRunMode(&run_mode);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MSI_LOG_ERROR << "acl get run mode failed";
|
||||
return false;
|
||||
}
|
||||
bool is_device = (run_mode == ACL_DEVICE);
|
||||
model_process_.SetIsDevice(is_device);
|
||||
MSI_LOG_INFO << "get run mode success is device input/output " << is_device;
|
||||
|
||||
MSI_LOG_INFO << "Init acl success, device id " << device_id_;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AclSession::FinalizeEnv() {
|
||||
aclError ret;
|
||||
if (stream_ != nullptr) {
|
||||
ret = aclrtDestroyStream(stream_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MSI_LOG_ERROR << "destroy stream failed";
|
||||
}
|
||||
stream_ = nullptr;
|
||||
}
|
||||
MSI_LOG_INFO << "end to destroy stream";
|
||||
if (context_ != nullptr) {
|
||||
ret = aclrtDestroyContext(context_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MSI_LOG_ERROR << "destroy context failed";
|
||||
}
|
||||
context_ = nullptr;
|
||||
}
|
||||
MSI_LOG_INFO << "end to destroy context";
|
||||
|
||||
ret = aclrtResetDevice(device_id_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MSI_LOG_ERROR << "reset devie " << device_id_ << " failed";
|
||||
}
|
||||
MSI_LOG_INFO << "end to reset device " << device_id_;
|
||||
|
||||
ret = aclFinalize();
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MSI_LOG_ERROR << "finalize acl failed";
|
||||
}
|
||||
MSI_LOG_INFO << "end to finalize acl";
|
||||
return true;
|
||||
}
|
||||
|
||||
AclSession::AclSession() = default;
|
||||
} // namespace mindspore::inference
|
|
@ -13,8 +13,8 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_SESSION_H
|
||||
#define MINDSPORE_CCSRC_BACKEND_SESSION_SESSION_H
|
||||
#ifndef MINDSPORE_SERVING_ACL_SESSION_H
|
||||
#define MINDSPORE_SERVING_ACL_SESSION_H
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
@ -23,31 +23,28 @@
|
|||
#include <memory>
|
||||
#include <map>
|
||||
|
||||
#include "backend/session/session_basic.h"
|
||||
#include "ir/anf.h"
|
||||
#include "include/inference.h"
|
||||
#include "serving/acl/model_process.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace inference {
|
||||
class Session : public MSSession {
|
||||
class AclSession : public InferSession {
|
||||
public:
|
||||
Session();
|
||||
AclSession();
|
||||
|
||||
uint32_t CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr) override;
|
||||
|
||||
MultiTensor RunGraph(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) override;
|
||||
|
||||
bool CheckModelInputs(uint32_t graph_id,
|
||||
const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) const override;
|
||||
|
||||
int Init(const std::string &device, uint32_t device_id);
|
||||
|
||||
static void RegAllOp();
|
||||
bool InitEnv(const std::string &device_type, uint32_t device_id) override;
|
||||
bool FinalizeEnv() override;
|
||||
bool LoadModelFromFile(const std::string &file_name, uint32_t &model_id) override;
|
||||
bool UnloadModel(uint32_t model_id) override;
|
||||
bool ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<session::SessionBasic> session_impl_ = nullptr;
|
||||
std::vector<uint32_t> graph_id_;
|
||||
std::string device_type_;
|
||||
int32_t device_id_;
|
||||
aclrtStream stream_ = nullptr;
|
||||
aclrtContext context_ = nullptr;
|
||||
ModelProcess model_process_;
|
||||
};
|
||||
} // namespace inference
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_SESSION_SESSION_BASIC_H
|
||||
#endif // MINDSPORE_SERVING_ACL_SESSION_H
|
|
@ -0,0 +1,340 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "serving/acl/model_process.h"
|
||||
#include <algorithm>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "include/infer_log.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace inference {
|
||||
|
||||
bool ModelProcess::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) {
|
||||
aclError acl_ret = aclmdlLoadFromFile(file_name.c_str(), &model_id);
|
||||
if (acl_ret != ACL_ERROR_NONE) {
|
||||
MSI_LOG_ERROR << "Read model file failed, file name is " << file_name;
|
||||
return false;
|
||||
}
|
||||
MSI_LOG_INFO << "Load model success " << file_name;
|
||||
|
||||
model_desc_ = aclmdlCreateDesc();
|
||||
acl_ret = aclmdlGetDesc(model_desc_, model_id);
|
||||
if (acl_ret != ACL_ERROR_NONE) {
|
||||
MSI_LOG_ERROR << "Read model desc failed";
|
||||
return false;
|
||||
}
|
||||
bool ret = InitInputsBuffer();
|
||||
if (!ret) {
|
||||
MSI_LOG_ERROR << "Create input buffer failed";
|
||||
return false;
|
||||
}
|
||||
ret = InitOutputsBuffer();
|
||||
if (!ret) {
|
||||
MSI_LOG_ERROR << "Create output buffer failed";
|
||||
return false;
|
||||
}
|
||||
model_id_ = model_id;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ModelProcess::InitInputsBuffer() {
|
||||
aclError ret;
|
||||
inputs_ = aclmdlCreateDataset();
|
||||
if (inputs_ == nullptr) {
|
||||
MSI_LOG_ERROR << "Create input dataset failed";
|
||||
return false;
|
||||
}
|
||||
size_t input_size = aclmdlGetNumInputs(model_desc_);
|
||||
|
||||
for (size_t i = 0; i < input_size; ++i) {
|
||||
auto buffer_size = aclmdlGetInputSizeByIndex(model_desc_, i);
|
||||
void *data_mem_buffer = nullptr;
|
||||
if (!is_run_on_device_) { // need to copy input/output to/from device
|
||||
ret = aclrtMalloc(&data_mem_buffer, buffer_size, ACL_MEM_MALLOC_NORMAL_ONLY);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MSI_LOG_ERROR << "Malloc device input buffer faild , input size " << buffer_size;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
aclmdlIODims dims;
|
||||
ret = aclmdlGetInputDims(model_desc_, i, &dims);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MSI_LOG_ERROR << "Get input shape failed";
|
||||
return false;
|
||||
}
|
||||
aclDataType dataType = aclmdlGetInputDataType(model_desc_, i);
|
||||
std::vector<int64_t> shape(dims.dims, dims.dims + dims.dimCount);
|
||||
input_infos_.emplace_back(AclTensorInfo{data_mem_buffer, buffer_size, dataType, shape});
|
||||
}
|
||||
MSI_LOG_INFO << "Create model inputs success";
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ModelProcess::CreateDataBuffer(void *&data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset) {
|
||||
aclError ret;
|
||||
auto free_data_buffer = [this](void *dataMemBuffer) {
|
||||
if (!is_run_on_device_) {
|
||||
aclrtFree(dataMemBuffer);
|
||||
} else {
|
||||
aclrtFreeHost(dataMemBuffer);
|
||||
}
|
||||
};
|
||||
if (!is_run_on_device_) {
|
||||
ret = aclrtMalloc(&data_mem_buffer, buffer_size, ACL_MEM_MALLOC_NORMAL_ONLY);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MSI_LOG_ERROR << "Malloc device buffer faild , buffer size " << buffer_size;
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
ret = aclrtMallocHost(&data_mem_buffer, buffer_size);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MSI_LOG_ERROR << "Malloc device buffer faild , buffer size " << buffer_size;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
auto data_buffer = aclCreateDataBuffer(data_mem_buffer, buffer_size);
|
||||
if (data_buffer == nullptr) {
|
||||
MSI_LOG_ERROR << "Create Data Buffer failed";
|
||||
free_data_buffer(data_mem_buffer);
|
||||
return false;
|
||||
}
|
||||
ret = aclmdlAddDatasetBuffer(dataset, data_buffer);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MSI_LOG_ERROR << "add data buffer failed";
|
||||
free_data_buffer(data_mem_buffer);
|
||||
aclDestroyDataBuffer(data_buffer);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ModelProcess::InitOutputsBuffer() {
|
||||
aclError ret;
|
||||
outputs_ = aclmdlCreateDataset();
|
||||
if (outputs_ == nullptr) {
|
||||
MSI_LOG_ERROR << "Create input dataset failed";
|
||||
return false;
|
||||
}
|
||||
size_t output_size = aclmdlGetNumOutputs(model_desc_);
|
||||
for (size_t i = 0; i < output_size; ++i) {
|
||||
auto buffer_size = aclmdlGetOutputSizeByIndex(model_desc_, i);
|
||||
|
||||
void *data_mem_buffer = nullptr;
|
||||
if (CreateDataBuffer(data_mem_buffer, buffer_size, outputs_) != true) {
|
||||
MSI_LOG_ERROR << "add output data buffer failed, buffer size " << buffer_size;
|
||||
return false;
|
||||
}
|
||||
aclmdlIODims dims;
|
||||
ret = aclmdlGetOutputDims(model_desc_, i, &dims);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MSI_LOG_ERROR << "Get input shape failed";
|
||||
return false;
|
||||
}
|
||||
aclDataType dataType = aclmdlGetOutputDataType(model_desc_, i);
|
||||
std::vector<int64_t> shape(dims.dims, dims.dims + dims.dimCount);
|
||||
output_infos_.emplace_back(AclTensorInfo{data_mem_buffer, buffer_size, dataType, shape});
|
||||
}
|
||||
MSI_LOG_INFO << "Create model output success";
|
||||
return true;
|
||||
}
|
||||
|
||||
void ModelProcess::DestroyInputsDataset() {
|
||||
if (inputs_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
for (size_t i = 0; i < aclmdlGetDatasetNumBuffers(inputs_); i++) {
|
||||
auto dataBuffer = aclmdlGetDatasetBuffer(inputs_, i);
|
||||
aclDestroyDataBuffer(dataBuffer);
|
||||
}
|
||||
aclmdlDestroyDataset(inputs_);
|
||||
inputs_ = nullptr;
|
||||
}
|
||||
|
||||
void ModelProcess::DestroyInputsDataMem() {
|
||||
if (!is_run_on_device_) {
|
||||
for (const auto &item : input_infos_) {
|
||||
aclrtFree(item.device_data);
|
||||
}
|
||||
}
|
||||
input_infos_.clear();
|
||||
}
|
||||
|
||||
void ModelProcess::DestroyInputsBuffer() {
|
||||
DestroyInputsDataset();
|
||||
DestroyInputsDataMem();
|
||||
}
|
||||
|
||||
void ModelProcess::DestroyOutputsBuffer() {
|
||||
if (outputs_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
for (size_t i = 0; i < aclmdlGetDatasetNumBuffers(outputs_); i++) {
|
||||
auto dataBuffer = aclmdlGetDatasetBuffer(outputs_, i);
|
||||
auto data = aclGetDataBufferAddr(dataBuffer);
|
||||
if (!is_run_on_device_) {
|
||||
aclrtFree(data);
|
||||
} else {
|
||||
aclrtFreeHost(data);
|
||||
}
|
||||
aclDestroyDataBuffer(dataBuffer);
|
||||
}
|
||||
aclmdlDestroyDataset(outputs_);
|
||||
outputs_ = nullptr;
|
||||
output_infos_.clear();
|
||||
}
|
||||
|
||||
void ModelProcess::UnLoad() {
|
||||
auto ret = aclmdlUnload(model_id_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MSI_LOG_ERROR << "Unload model failed";
|
||||
}
|
||||
if (model_desc_ != nullptr) {
|
||||
aclmdlDestroyDesc(model_desc_);
|
||||
model_desc_ = nullptr;
|
||||
}
|
||||
DestroyInputsBuffer();
|
||||
DestroyOutputsBuffer();
|
||||
MSI_LOG_INFO << "End unload model " << model_id_;
|
||||
}
|
||||
|
||||
bool ModelProcess::CheckAndInitInput(const RequestBase &request) {
|
||||
aclError ret;
|
||||
inputs_ = aclmdlCreateDataset();
|
||||
// check inputs
|
||||
if (request.size() != input_infos_.size()) {
|
||||
MSI_LOG_ERROR << "inputs count not match, required count " << input_infos_.size() << ", given count "
|
||||
<< request.size();
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < input_infos_.size(); i++) {
|
||||
if (request[i] == nullptr) {
|
||||
MSI_LOG_ERROR << "input " << i << " cannot be null";
|
||||
return false;
|
||||
}
|
||||
if (request[i]->data_size() != input_infos_[i].buffer_size) {
|
||||
MSI_LOG_ERROR << "input " << i << " data size not match, required size " << input_infos_[i].buffer_size
|
||||
<< ", given count " << request[i]->data_size();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// copy inputs
|
||||
for (size_t i = 0; i < input_infos_.size(); i++) {
|
||||
void *input_buffer = nullptr;
|
||||
auto &info = input_infos_[i];
|
||||
const void *data = request[i]->data();
|
||||
if (!is_run_on_device_) {
|
||||
ret = aclrtMemcpy(info.device_data, info.buffer_size, data, request[i]->data_size(), ACL_MEMCPY_HOST_TO_DEVICE);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MSI_LOG_ERROR << "memcpy input " << i << " data to device failed, buffer size " << request[i]->data_size();
|
||||
return false;
|
||||
}
|
||||
input_buffer = info.device_data;
|
||||
} else {
|
||||
input_buffer = const_cast<void *>(data);
|
||||
}
|
||||
auto data_buffer = aclCreateDataBuffer(input_buffer, info.buffer_size);
|
||||
if (data_buffer == nullptr) {
|
||||
MSI_LOG_ERROR << "Create Data Buffer failed";
|
||||
return false;
|
||||
}
|
||||
ret = aclmdlAddDatasetBuffer(inputs_, data_buffer);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MSI_LOG_ERROR << "add data buffer failed";
|
||||
aclDestroyDataBuffer(data_buffer);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ModelProcess::BuildOutputs(ReplyBase &reply) {
|
||||
aclError ret;
|
||||
// copy outputs
|
||||
reply.clear();
|
||||
|
||||
std::unordered_map<aclDataType, inference::DataType> dataTypeMap = {
|
||||
{ACL_FLOAT16, inference::kMSI_Float16}, {ACL_FLOAT, inference::kMSI_Float32}, {ACL_DOUBLE, inference::kMSI_Float64},
|
||||
{ACL_INT8, inference::kMSI_Int8}, {ACL_INT16, inference::kMSI_Int16}, {ACL_INT32, inference::kMSI_Int32},
|
||||
{ACL_INT64, inference::kMSI_Int64}, {ACL_UINT8, inference::kMSI_Uint8}, {ACL_UINT16, inference::kMSI_Uint16},
|
||||
{ACL_UINT32, inference::kMSI_Uint32}, {ACL_UINT64, inference::kMSI_Uint64}, {ACL_BOOL, inference::kMSI_Bool},
|
||||
};
|
||||
auto trans_to_serving_type = [&dataTypeMap](aclDataType data_type) {
|
||||
auto it = dataTypeMap.find(data_type);
|
||||
if (it == dataTypeMap.end()) {
|
||||
return inference::kMSI_Unknown;
|
||||
} else {
|
||||
return it->second;
|
||||
}
|
||||
};
|
||||
for (size_t i = 0; i < output_infos_.size(); i++) {
|
||||
auto &info = output_infos_[i];
|
||||
auto output = reply.add();
|
||||
if (output == nullptr) {
|
||||
MSI_LOG_ERROR << "add new output failed";
|
||||
return false;
|
||||
}
|
||||
output->set_data_type(trans_to_serving_type(info.data_type));
|
||||
output->set_shape(info.dims);
|
||||
if (!output->resize_data(info.buffer_size)) {
|
||||
MSI_LOG_ERROR << "new output data buffer failed, data size " << info.buffer_size;
|
||||
return false;
|
||||
}
|
||||
if (!is_run_on_device_) {
|
||||
ret = aclrtMemcpy(output->mutable_data(), output->data_size(), info.device_data, info.buffer_size,
|
||||
ACL_MEMCPY_DEVICE_TO_HOST);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MSI_LOG_ERROR << "Memcpy output " << i << " to host failed, memory size " << info.buffer_size;
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
ret = aclrtMemcpy(output->mutable_data(), output->data_size(), info.device_data, info.buffer_size,
|
||||
ACL_MEMCPY_HOST_TO_HOST);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MSI_LOG_ERROR << "Memcpy output " << i << " to host failed, memory size " << info.buffer_size;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ModelProcess::Execute(const RequestBase &request, ReplyBase &reply) {
|
||||
aclError acl_ret;
|
||||
if (CheckAndInitInput(request) != true) {
|
||||
MSI_LOG_ERROR << "check or init input failed";
|
||||
DestroyInputsDataset();
|
||||
return false;
|
||||
}
|
||||
acl_ret = aclmdlExecute(model_id_, inputs_, outputs_);
|
||||
DestroyInputsDataset();
|
||||
if (acl_ret != ACL_ERROR_NONE) {
|
||||
MSI_LOG_ERROR << "Execute Model Failed";
|
||||
return false;
|
||||
}
|
||||
bool ret = BuildOutputs(reply);
|
||||
if (!ret) {
|
||||
MSI_LOG_ERROR << "Build outputs faield";
|
||||
return false;
|
||||
}
|
||||
MSI_LOG_INFO << "excute model success";
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace inference
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,74 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef INC_MODEL_PROCESS_ACL
|
||||
#define INC_MODEL_PROCESS_ACL
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "acl/acl.h"
|
||||
#include "acl/acl_mdl.h"
|
||||
#include "acl/acl_rt.h"
|
||||
#include "serving/core/util/status.h"
|
||||
#include "include/inference.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace inference {
|
||||
|
||||
struct AclTensorInfo {
|
||||
void *device_data;
|
||||
size_t buffer_size;
|
||||
aclDataType data_type;
|
||||
std::vector<int64_t> dims;
|
||||
};
|
||||
|
||||
class ModelProcess {
|
||||
public:
|
||||
ModelProcess() {}
|
||||
~ModelProcess() {}
|
||||
|
||||
bool LoadModelFromFile(const std::string &file_name, uint32_t &model_id);
|
||||
void UnLoad();
|
||||
|
||||
// override this method to avoid request/reply data copy
|
||||
bool Execute(const RequestBase &request, ReplyBase &reply);
|
||||
|
||||
void SetIsDevice(bool is_device) { is_run_on_device_ = is_device; }
|
||||
|
||||
private:
|
||||
uint32_t model_id_ = 0xffffffff;
|
||||
bool is_run_on_device_ = false;
|
||||
aclmdlDesc *model_desc_ = nullptr;
|
||||
aclmdlDataset *inputs_ = nullptr;
|
||||
aclmdlDataset *outputs_ = nullptr;
|
||||
std::vector<AclTensorInfo> input_infos_;
|
||||
std::vector<AclTensorInfo> output_infos_;
|
||||
|
||||
bool CreateDataBuffer(void *&data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset);
|
||||
bool CheckAndInitInput(const RequestBase &request);
|
||||
bool BuildOutputs(ReplyBase &reply);
|
||||
|
||||
bool InitInputsBuffer();
|
||||
bool InitOutputsBuffer();
|
||||
void DestroyInputsDataset();
|
||||
void DestroyInputsDataMem();
|
||||
void DestroyInputsBuffer();
|
||||
void DestroyOutputsBuffer();
|
||||
};
|
||||
|
||||
} // namespace inference
|
||||
} // namespace mindspore
|
||||
|
||||
#endif
|
|
@ -23,14 +23,14 @@
|
|||
#include <utility>
|
||||
#include <memory>
|
||||
#include <future>
|
||||
#include <chrono>
|
||||
|
||||
#include "mindspore/ccsrc/utils/log_adapter.h"
|
||||
#include "include/infer_log.h"
|
||||
#include "serving/ms_service.grpc.pb.h"
|
||||
#include "core/util/option_parser.h"
|
||||
#include "core/version_control/version_controller.h"
|
||||
#include "mindspore/ccsrc/utils/context/ms_context.h"
|
||||
#include "core/util/file_system_operation.h"
|
||||
#include "graphengine/third_party/fwkacllib/inc/runtime/context.h"
|
||||
#include "core/serving_tensor.h"
|
||||
|
||||
using ms_serving::MSService;
|
||||
using ms_serving::PredictReply;
|
||||
|
@ -38,12 +38,19 @@ using ms_serving::PredictRequest;
|
|||
|
||||
namespace mindspore {
|
||||
namespace serving {
|
||||
using MSTensorPtr = std::shared_ptr<inference::MSTensor>;
|
||||
|
||||
#define MSI_TIME_STAMP_START(name) auto time_start_##name = std::chrono::steady_clock::now();
|
||||
#define MSI_TIME_STAMP_END(name) \
|
||||
{ \
|
||||
auto time_end_##name = std::chrono::steady_clock::now(); \
|
||||
auto time_cost = std::chrono::duration<double, std::milli>(time_end_##name - time_start_##name).count(); \
|
||||
MSI_LOG_INFO << #name " Time Cost " << time_cost << "ms ---------------------"; \
|
||||
}
|
||||
|
||||
Status Session::CreatDeviceSession(const std::string &device, uint32_t device_id) {
|
||||
session_ = inference::MSSession::CreateSession(device, device_id);
|
||||
session_ = inference::InferSession::CreateSession(device, device_id);
|
||||
if (session_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Creat Session Failed";
|
||||
MSI_LOG(ERROR) << "Creat Session Failed";
|
||||
return FAILED;
|
||||
}
|
||||
device_type_ = device;
|
||||
|
@ -55,53 +62,56 @@ Session &Session::Instance() {
|
|||
return instance;
|
||||
}
|
||||
|
||||
Status Session::Predict(const std::vector<MSTensorPtr> &inputs, inference::MultiTensor *outputs) {
|
||||
if (last_graph_ == nullptr) {
|
||||
MS_LOG(ERROR) << "the model has not loaded";
|
||||
Status Session::Predict(const PredictRequest &request, PredictReply &reply) {
|
||||
if (!model_loaded_) {
|
||||
MSI_LOG(ERROR) << "the model has not loaded";
|
||||
return FAILED;
|
||||
}
|
||||
if (session_ == nullptr) {
|
||||
MS_LOG(ERROR) << "the inference session has not be initialized";
|
||||
MSI_LOG(ERROR) << "the inference session has not be initialized";
|
||||
return FAILED;
|
||||
}
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
MS_LOG(INFO) << "run Predict";
|
||||
MSI_LOG(INFO) << "run Predict";
|
||||
|
||||
if (!session_->CheckModelInputs(graph_id_, inputs)) {
|
||||
MS_LOG(ERROR) << "Input error.";
|
||||
ServingRequest serving_request(request);
|
||||
ServingReply serving_reply(reply);
|
||||
|
||||
auto ret = session_->ExecuteModel(graph_id_, serving_request, serving_reply);
|
||||
MSI_LOG(INFO) << "run Predict finished";
|
||||
if (!ret) {
|
||||
MSI_LOG(ERROR) << "execute model return failed";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
*outputs = session_->RunGraph(graph_id_, inputs);
|
||||
MS_LOG(INFO) << "run Predict finished";
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status Session::Warmup(const MindSporeModelPtr model) {
|
||||
if (session_ == nullptr) {
|
||||
MS_LOG(ERROR) << "The CreatDeviceSession should be called, before warmup";
|
||||
MSI_LOG(ERROR) << "The CreatDeviceSession should be called, before warmup";
|
||||
return FAILED;
|
||||
}
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
size_t size = 0;
|
||||
std::string file_name = model->GetModelPath() + '/' + model->GetModelName();
|
||||
char *graphBuf = ReadFile(file_name.c_str(), &size);
|
||||
if (graphBuf == nullptr) {
|
||||
MS_LOG(ERROR) << "Read model file failed, file name is " << file_name.c_str();
|
||||
model_loaded_ = false;
|
||||
MSI_TIME_STAMP_START(LoadModelFromFile)
|
||||
auto ret = session_->LoadModelFromFile(file_name, graph_id_);
|
||||
MSI_TIME_STAMP_END(LoadModelFromFile)
|
||||
if (!ret) {
|
||||
MSI_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
|
||||
return FAILED;
|
||||
}
|
||||
last_graph_ = inference::LoadModel(graphBuf, size, device_type_);
|
||||
if (last_graph_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
|
||||
return FAILED;
|
||||
}
|
||||
graph_id_ = session_->CompileGraph(last_graph_);
|
||||
MS_LOG(INFO) << "Session Warmup finished";
|
||||
model_loaded_ = true;
|
||||
MSI_LOG(INFO) << "Session Warmup finished";
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status Session::Clear() {
|
||||
session_ = nullptr;
|
||||
if (session_ != nullptr) {
|
||||
session_->UnloadModel(graph_id_);
|
||||
session_->FinalizeEnv();
|
||||
session_ = nullptr;
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
|
@ -109,121 +119,30 @@ namespace {
|
|||
static const uint32_t uint32max = 0x7FFFFFFF;
|
||||
std::promise<void> exit_requested;
|
||||
|
||||
const std::map<ms_serving::DataType, TypeId> type2id_map{
|
||||
{ms_serving::MS_UNKNOWN, TypeId::kNumberTypeBegin}, {ms_serving::MS_BOOL, TypeId::kNumberTypeBool},
|
||||
{ms_serving::MS_INT8, TypeId::kNumberTypeInt8}, {ms_serving::MS_UINT8, TypeId::kNumberTypeUInt8},
|
||||
{ms_serving::MS_INT16, TypeId::kNumberTypeInt16}, {ms_serving::MS_UINT16, TypeId::kNumberTypeUInt16},
|
||||
{ms_serving::MS_INT32, TypeId::kNumberTypeInt32}, {ms_serving::MS_UINT32, TypeId::kNumberTypeUInt32},
|
||||
{ms_serving::MS_INT64, TypeId::kNumberTypeInt64}, {ms_serving::MS_UINT64, TypeId::kNumberTypeUInt64},
|
||||
{ms_serving::MS_FLOAT16, TypeId::kNumberTypeFloat16}, {ms_serving::MS_FLOAT32, TypeId::kNumberTypeFloat32},
|
||||
{ms_serving::MS_FLOAT64, TypeId::kNumberTypeFloat64},
|
||||
};
|
||||
|
||||
const std::map<TypeId, ms_serving::DataType> id2type_map{
|
||||
{TypeId::kNumberTypeBegin, ms_serving::MS_UNKNOWN}, {TypeId::kNumberTypeBool, ms_serving::MS_BOOL},
|
||||
{TypeId::kNumberTypeInt8, ms_serving::MS_INT8}, {TypeId::kNumberTypeUInt8, ms_serving::MS_UINT8},
|
||||
{TypeId::kNumberTypeInt16, ms_serving::MS_INT16}, {TypeId::kNumberTypeUInt16, ms_serving::MS_UINT16},
|
||||
{TypeId::kNumberTypeInt32, ms_serving::MS_INT32}, {TypeId::kNumberTypeUInt32, ms_serving::MS_UINT32},
|
||||
{TypeId::kNumberTypeInt64, ms_serving::MS_INT64}, {TypeId::kNumberTypeUInt64, ms_serving::MS_UINT64},
|
||||
{TypeId::kNumberTypeFloat16, ms_serving::MS_FLOAT16}, {TypeId::kNumberTypeFloat32, ms_serving::MS_FLOAT32},
|
||||
{TypeId::kNumberTypeFloat64, ms_serving::MS_FLOAT64},
|
||||
};
|
||||
const std::map<ms_serving::DataType, size_t> length_map{
|
||||
{ms_serving::MS_UNKNOWN, 0},
|
||||
{ms_serving::MS_BOOL, sizeof(bool)},
|
||||
{ms_serving::MS_INT8, sizeof(int8_t)},
|
||||
{ms_serving::MS_UINT8, sizeof(uint8_t)},
|
||||
{ms_serving::MS_INT16, sizeof(int16_t)},
|
||||
{ms_serving::MS_UINT16, sizeof(uint16_t)},
|
||||
{ms_serving::MS_INT32, sizeof(int32_t)},
|
||||
{ms_serving::MS_UINT32, sizeof(uint32_t)},
|
||||
{ms_serving::MS_INT64, sizeof(int64_t)},
|
||||
{ms_serving::MS_UINT64, sizeof(uint64_t)},
|
||||
{ms_serving::MS_FLOAT16, 2},
|
||||
{ms_serving::MS_FLOAT32, 4},
|
||||
{ms_serving::MS_FLOAT64, 8},
|
||||
};
|
||||
MSTensorPtr ServingTensor2MSTensor(const ms_serving::Tensor &tensor) {
|
||||
std::vector<int> shape;
|
||||
for (auto dim : tensor.tensor_shape().dims()) {
|
||||
shape.push_back(static_cast<int>(dim));
|
||||
}
|
||||
auto iter = type2id_map.find(tensor.tensor_type());
|
||||
if (iter == type2id_map.end()) {
|
||||
MS_LOG(ERROR) << "input tensor type is wrong, type is " << tensor.tensor_type();
|
||||
return nullptr;
|
||||
}
|
||||
TypeId type = iter->second;
|
||||
auto ms_tensor = std::shared_ptr<inference::MSTensor>(inference::MSTensor::CreateTensor(type, shape));
|
||||
memcpy_s(ms_tensor->MutableData(), ms_tensor->Size(), tensor.data().data(), tensor.data().size());
|
||||
return ms_tensor;
|
||||
}
|
||||
|
||||
ms_serving::Tensor MSTensor2ServingTensor(MSTensorPtr ms_tensor) {
|
||||
ms_serving::Tensor tensor;
|
||||
ms_serving::TensorShape shape;
|
||||
for (auto dim : ms_tensor->shape()) {
|
||||
shape.add_dims(dim);
|
||||
}
|
||||
*tensor.mutable_tensor_shape() = shape;
|
||||
auto iter = id2type_map.find(ms_tensor->data_type());
|
||||
if (iter == id2type_map.end()) {
|
||||
MS_LOG(ERROR) << "input tensor type is wrong, type is " << tensor.tensor_type();
|
||||
return tensor;
|
||||
}
|
||||
tensor.set_tensor_type(iter->second);
|
||||
tensor.set_data(ms_tensor->MutableData(), ms_tensor->Size());
|
||||
return tensor;
|
||||
}
|
||||
|
||||
void ClearEnv() {
|
||||
Session::Instance().Clear();
|
||||
inference::ExitInference();
|
||||
// inference::ExitInference();
|
||||
}
|
||||
void HandleSignal(int sig) { exit_requested.set_value(); }
|
||||
|
||||
#ifdef ENABLE_D
|
||||
static rtContext_t g_ctx = nullptr;
|
||||
#endif
|
||||
} // namespace
|
||||
|
||||
// Service Implement
|
||||
class MSServiceImpl final : public MSService::Service {
|
||||
grpc::Status Predict(grpc::ServerContext *context, const PredictRequest *request, PredictReply *reply) override {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
#ifdef ENABLE_D
|
||||
if (g_ctx == nullptr) {
|
||||
MS_LOG(ERROR) << "rtCtx is nullptr";
|
||||
return grpc::Status::CANCELLED;
|
||||
}
|
||||
rtError_t rt_ret = rtCtxSetCurrent(g_ctx);
|
||||
if (rt_ret != RT_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "set Ascend rtCtx failed";
|
||||
}
|
||||
#endif
|
||||
std::vector<MSTensorPtr> inputs;
|
||||
inference::MultiTensor outputs;
|
||||
for (int i = 0; i < request->data_size(); i++) {
|
||||
auto input = ServingTensor2MSTensor(request->data(i));
|
||||
if (input == nullptr) {
|
||||
MS_LOG(ERROR) << "Tensor convert failed";
|
||||
return grpc::Status::CANCELLED;
|
||||
}
|
||||
inputs.push_back(input);
|
||||
}
|
||||
auto res = Session::Instance().Predict(inputs, &outputs);
|
||||
MSI_TIME_STAMP_START(Predict)
|
||||
auto res = Session::Instance().Predict(*request, *reply);
|
||||
MSI_TIME_STAMP_END(Predict)
|
||||
if (res != SUCCESS) {
|
||||
return grpc::Status::CANCELLED;
|
||||
}
|
||||
for (const auto &tensor : outputs) {
|
||||
*reply->add_result() = MSTensor2ServingTensor(tensor);
|
||||
}
|
||||
MS_LOG(INFO) << "Finish call service Eval";
|
||||
MSI_LOG(INFO) << "Finish call service Eval";
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
grpc::Status Test(grpc::ServerContext *context, const PredictRequest *request, PredictReply *reply) override {
|
||||
MS_LOG(INFO) << "TestService call";
|
||||
MSI_LOG(INFO) << "TestService call";
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
std::mutex mutex_;
|
||||
|
@ -242,28 +161,17 @@ Status Server::BuildAndStart() {
|
|||
auto device_id = option_args->device_id;
|
||||
res = Session::Instance().CreatDeviceSession(device_type, device_id);
|
||||
if (res != SUCCESS) {
|
||||
MS_LOG(ERROR) << "creat session failed";
|
||||
MSI_LOG(ERROR) << "creat session failed";
|
||||
ClearEnv();
|
||||
return res;
|
||||
}
|
||||
VersionController version_controller(option_args->poll_model_wait_seconds, model_path, model_name);
|
||||
res = version_controller.Run();
|
||||
if (res != SUCCESS) {
|
||||
MS_LOG(ERROR) << "load model failed";
|
||||
MSI_LOG(ERROR) << "load model failed";
|
||||
ClearEnv();
|
||||
return res;
|
||||
}
|
||||
#ifdef ENABLE_D
|
||||
// set d context
|
||||
rtContext_t ctx = nullptr;
|
||||
rtError_t rt_ret = rtCtxGetCurrent(&ctx);
|
||||
if (rt_ret != RT_ERROR_NONE || ctx == nullptr) {
|
||||
MS_LOG(ERROR) << "the ascend device context is null";
|
||||
ClearEnv();
|
||||
return FAILED;
|
||||
}
|
||||
g_ctx = ctx;
|
||||
#endif
|
||||
MSServiceImpl ms_service;
|
||||
grpc::EnableDefaultHealthCheckService(true);
|
||||
grpc::reflection::InitProtoReflectionServerBuilderPlugin();
|
||||
|
@ -276,13 +184,13 @@ Status Server::BuildAndStart() {
|
|||
serverBuilder.RegisterService(&ms_service);
|
||||
std::unique_ptr<grpc::Server> server(serverBuilder.BuildAndStart());
|
||||
if (server == nullptr) {
|
||||
MS_LOG(ERROR) << "The serving server create failed";
|
||||
MSI_LOG(ERROR) << "The serving server create failed";
|
||||
ClearEnv();
|
||||
return FAILED;
|
||||
}
|
||||
auto grpc_server_run = [&server]() { server->Wait(); };
|
||||
std::thread serving_thread(grpc_server_run);
|
||||
MS_LOG(INFO) << "MS Serving listening on " << server_address;
|
||||
MSI_LOG(INFO) << "MS Serving listening on " << server_address;
|
||||
auto exit_future = exit_requested.get_future();
|
||||
exit_future.wait();
|
||||
ClearEnv();
|
||||
|
|
|
@ -23,14 +23,21 @@
|
|||
#include "util/status.h"
|
||||
#include "version_control/model.h"
|
||||
#include "include/inference.h"
|
||||
#include "mindspore/ccsrc/debug/info.h"
|
||||
#include "serving/ms_service.pb.h"
|
||||
#include "serving/ms_service.grpc.pb.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace serving {
|
||||
|
||||
using ms_serving::PredictReply;
|
||||
using ms_serving::PredictRequest;
|
||||
|
||||
class Session {
|
||||
public:
|
||||
static Session &Instance();
|
||||
Status CreatDeviceSession(const std::string &device, uint32_t device_id);
|
||||
Status Predict(const std::vector<std::shared_ptr<inference::MSTensor>> &inputs, inference::MultiTensor *output);
|
||||
// Status Predict(const inference::MultiTensor &inputs, inference::MultiTensor &output);
|
||||
Status Predict(const PredictRequest &request, PredictReply &reply);
|
||||
Status Warmup(const MindSporeModelPtr model);
|
||||
Status Clear();
|
||||
|
||||
|
@ -38,8 +45,8 @@ class Session {
|
|||
Session() = default;
|
||||
~Session() = default;
|
||||
int sesseion_id_{0};
|
||||
std::shared_ptr<inference::MSSession> session_{nullptr};
|
||||
FuncGraphPtr last_graph_{nullptr};
|
||||
std::shared_ptr<inference::InferSession> session_{nullptr};
|
||||
bool model_loaded_ = false;
|
||||
uint32_t graph_id_{0};
|
||||
std::mutex mutex_;
|
||||
std::string device_type_;
|
||||
|
|
|
@ -0,0 +1,164 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "core/serving_tensor.h"
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include "include/infer_log.h"
|
||||
|
||||
using std::string;
|
||||
using std::unordered_map;
|
||||
using std::vector;
|
||||
|
||||
namespace mindspore {
|
||||
namespace serving {
|
||||
|
||||
using inference::DataType;
|
||||
using inference::InferTensorBase;
|
||||
|
||||
const size_t kMaxShapeElementCount = INT32_MAX;
|
||||
const size_t kMaxDataBufferSize = UINT32_MAX;
|
||||
|
||||
ServingTensor::ServingTensor(ms_serving::Tensor &other) : tensor_(other) {}
|
||||
|
||||
ServingTensor::~ServingTensor() {}
|
||||
|
||||
DataType ServingTensor::data_type() const {
|
||||
const std::unordered_map<ms_serving::DataType, inference::DataType> type2id_map{
|
||||
{ms_serving::MS_UNKNOWN, inference::kMSI_Unknown}, {ms_serving::MS_BOOL, inference::kMSI_Bool},
|
||||
{ms_serving::MS_INT8, inference::kMSI_Int8}, {ms_serving::MS_UINT8, inference::kMSI_Uint8},
|
||||
{ms_serving::MS_INT16, inference::kMSI_Int16}, {ms_serving::MS_UINT16, inference::kMSI_Uint16},
|
||||
{ms_serving::MS_INT32, inference::kMSI_Int32}, {ms_serving::MS_UINT32, inference::kMSI_Uint32},
|
||||
{ms_serving::MS_INT64, inference::kMSI_Int64}, {ms_serving::MS_UINT64, inference::kMSI_Uint64},
|
||||
{ms_serving::MS_FLOAT16, inference::kMSI_Float16}, {ms_serving::MS_FLOAT32, inference::kMSI_Float32},
|
||||
{ms_serving::MS_FLOAT64, inference::kMSI_Float64},
|
||||
};
|
||||
auto it = type2id_map.find(tensor_.tensor_type());
|
||||
if (it == type2id_map.end()) {
|
||||
MSI_LOG_WARNING << "failed to get data type, undefined data type " << tensor_.tensor_type();
|
||||
return inference::kMSI_Unknown;
|
||||
} else {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
void ServingTensor::set_data_type(DataType data_type) {
|
||||
const std::unordered_map<inference::DataType, ms_serving::DataType> id2type_map{
|
||||
{inference::kMSI_Unknown, ms_serving::MS_UNKNOWN}, {inference::kMSI_Bool, ms_serving::MS_BOOL},
|
||||
{inference::kMSI_Float64, ms_serving::MS_FLOAT64}, {inference::kMSI_Int8, ms_serving::MS_INT8},
|
||||
{inference::kMSI_Uint8, ms_serving::MS_UINT8}, {inference::kMSI_Int16, ms_serving::MS_INT16},
|
||||
{inference::kMSI_Uint16, ms_serving::MS_UINT16}, {inference::kMSI_Int32, ms_serving::MS_INT32},
|
||||
{inference::kMSI_Uint32, ms_serving::MS_UINT32}, {inference::kMSI_Int64, ms_serving::MS_INT64},
|
||||
{inference::kMSI_Uint64, ms_serving::MS_UINT64}, {inference::kMSI_Float16, ms_serving::MS_FLOAT16},
|
||||
{inference::kMSI_Float32, ms_serving::MS_FLOAT32},
|
||||
};
|
||||
auto it = id2type_map.find(data_type);
|
||||
if (it == id2type_map.end()) {
|
||||
MSI_LOG_WARNING << "failed to set data type, undefined data type " << data_type;
|
||||
tensor_.set_tensor_type(ms_serving::MS_UNKNOWN);
|
||||
} else {
|
||||
tensor_.set_tensor_type(it->second);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int64_t> ServingTensor::shape() const {
|
||||
std::vector<int64_t> result;
|
||||
auto dims = tensor_.tensor_shape().dims();
|
||||
std::transform(dims.begin(), dims.end(), std::back_inserter(result), [](const int64_t dim) { return dim; });
|
||||
return result;
|
||||
}
|
||||
|
||||
void ServingTensor::set_shape(const std::vector<int64_t> &shape) {
|
||||
auto tensor_shape = tensor_.mutable_tensor_shape();
|
||||
tensor_shape->Clear();
|
||||
size_t element_count = 1;
|
||||
for (auto dim : shape) {
|
||||
if (dim <= 0 || element_count > kMaxShapeElementCount / dim) {
|
||||
MSI_LOG_ERROR << "failed to set shape, invalid dim num " << dim;
|
||||
tensor_shape->Clear();
|
||||
return;
|
||||
}
|
||||
element_count *= dim;
|
||||
tensor_shape->add_dims(dim);
|
||||
}
|
||||
}
|
||||
|
||||
bool ServingTensor::resize_data(size_t data_len) {
|
||||
string *buffer = tensor_.mutable_data();
|
||||
if (buffer == nullptr) {
|
||||
MSI_LOG_ERROR << "invalid buffer data";
|
||||
return false;
|
||||
}
|
||||
buffer->resize(data_len);
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t ServingTensor::data_size() const { return tensor_.data().size(); }
|
||||
|
||||
void *ServingTensor::mutable_data() { return const_cast<char *>(tensor_.mutable_data()->data()); }
|
||||
|
||||
const void *ServingTensor::data() const { return tensor_.data().data(); }
|
||||
|
||||
ServingRequest::ServingRequest(const ms_serving::PredictRequest &request) : request_(request) {
|
||||
auto &data = request_.data();
|
||||
std::transform(data.begin(), data.end(), std::back_inserter(cache_),
|
||||
[](const ms_serving::Tensor &item) { return ServingTensor(const_cast<ms_serving::Tensor &>(item)); });
|
||||
}
|
||||
|
||||
size_t ServingRequest::size() const { return request_.data_size(); }
|
||||
|
||||
const InferTensorBase *ServingRequest::operator[](size_t index) const {
|
||||
if (index >= cache_.size()) {
|
||||
MSI_LOG_ERROR << "visit invalid index " << index << " total size " << cache_.size();
|
||||
return nullptr;
|
||||
}
|
||||
return &(cache_[index]);
|
||||
}
|
||||
|
||||
size_t ServingReply::size() const { return cache_.size(); }
|
||||
|
||||
InferTensorBase *ServingReply::operator[](size_t index) {
|
||||
if (index >= cache_.size()) {
|
||||
MSI_LOG_ERROR << "visit invalid index " << index << " total size " << cache_.size();
|
||||
return nullptr;
|
||||
}
|
||||
return &(cache_[index]);
|
||||
}
|
||||
|
||||
const InferTensorBase *ServingReply::operator[](size_t index) const {
|
||||
if (index >= cache_.size()) {
|
||||
MSI_LOG_ERROR << "visit invalid index " << index << " total size " << cache_.size();
|
||||
return nullptr;
|
||||
}
|
||||
return &(cache_[index]);
|
||||
}
|
||||
|
||||
InferTensorBase *ServingReply::add() {
|
||||
auto new_item = reply_.add_result();
|
||||
if (new_item == nullptr) {
|
||||
MSI_LOG_ERROR << "add new item failed, current total size " << cache_.size();
|
||||
return nullptr;
|
||||
}
|
||||
cache_.push_back(ServingTensor(*new_item));
|
||||
return &(cache_.back());
|
||||
}
|
||||
|
||||
void ServingReply::clear() { reply_.mutable_result()->Clear(); }
|
||||
|
||||
} // namespace serving
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,79 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_SERVING_TENSOR_H_
|
||||
#define MINDSPORE_SERVING_TENSOR_H_
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "include/infer_tensor.h"
|
||||
#include "serving/ms_service.pb.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace serving {
|
||||
|
||||
class MS_API ServingTensor : public inference::InferTensorBase {
|
||||
public:
|
||||
// the other's lifetime must longer than this object
|
||||
explicit ServingTensor(ms_serving::Tensor &other);
|
||||
~ServingTensor();
|
||||
|
||||
inference::DataType data_type() const override;
|
||||
void set_data_type(inference::DataType type) override;
|
||||
std::vector<int64_t> shape() const override;
|
||||
void set_shape(const std::vector<int64_t> &shape) override;
|
||||
const void *data() const override;
|
||||
size_t data_size() const override;
|
||||
bool resize_data(size_t data_len) override;
|
||||
void *mutable_data() override;
|
||||
|
||||
private:
|
||||
// if tensor_ is reference from other ms_serving::Tensor, the other's lifetime must
|
||||
// longer than this object
|
||||
ms_serving::Tensor &tensor_;
|
||||
};
|
||||
|
||||
class ServingRequest : public inference::RequestBase {
|
||||
public:
|
||||
explicit ServingRequest(const ms_serving::PredictRequest &request);
|
||||
|
||||
size_t size() const override;
|
||||
const inference::InferTensorBase *operator[](size_t index) const override;
|
||||
|
||||
private:
|
||||
const ms_serving::PredictRequest &request_;
|
||||
std::vector<ServingTensor> cache_;
|
||||
};
|
||||
|
||||
class ServingReply : public inference::ReplyBase {
|
||||
public:
|
||||
explicit ServingReply(ms_serving::PredictReply &reply) : reply_(reply) {}
|
||||
|
||||
size_t size() const override;
|
||||
inference::InferTensorBase *operator[](size_t index) override;
|
||||
const inference::InferTensorBase *operator[](size_t index) const override;
|
||||
inference::InferTensorBase *add() override;
|
||||
void clear() override;
|
||||
|
||||
private:
|
||||
ms_serving::PredictReply &reply_;
|
||||
std::vector<ServingTensor> cache_;
|
||||
};
|
||||
|
||||
} // namespace serving
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_SERVING_TENSOR_H_
|
|
@ -25,43 +25,10 @@
|
|||
#include <ctime>
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include "mindspore/ccsrc/utils/log_adapter.h"
|
||||
#include "include/infer_log.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace serving {
|
||||
char *ReadFile(const char *file, size_t *size) {
|
||||
if (file == nullptr) {
|
||||
MS_LOG(ERROR) << "file is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
MS_ASSERT(size != nullptr);
|
||||
std::string realPath = file;
|
||||
std::ifstream ifs(realPath);
|
||||
if (!ifs.good()) {
|
||||
MS_LOG(ERROR) << "file: " << realPath << " is not exist";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!ifs.is_open()) {
|
||||
MS_LOG(ERROR) << "file: " << realPath << "open failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ifs.seekg(0, std::ios::end);
|
||||
*size = ifs.tellg();
|
||||
std::unique_ptr<char> buf(new (std::nothrow) char[*size]);
|
||||
if (buf == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc buf failed, file: " << realPath;
|
||||
ifs.close();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ifs.seekg(0, std::ios::beg);
|
||||
ifs.read(buf.get(), *size);
|
||||
ifs.close();
|
||||
|
||||
return buf.release();
|
||||
}
|
||||
|
||||
bool DirOrFileExist(const std::string &file_path) {
|
||||
int ret = access(file_path.c_str(), 0);
|
||||
|
@ -74,7 +41,7 @@ std::vector<std::string> GetAllSubDirs(const std::string &dir_path) {
|
|||
std::vector<std::string> SubDirs;
|
||||
|
||||
if ((dir = opendir(dir_path.c_str())) == NULL) {
|
||||
MS_LOG(ERROR) << "Open " << dir_path << " error!";
|
||||
MSI_LOG(ERROR) << "Open " << dir_path << " error!";
|
||||
return std::vector<std::string>();
|
||||
}
|
||||
|
||||
|
|
|
@ -19,10 +19,11 @@
|
|||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <iomanip>
|
||||
#include "mindspore/ccsrc/utils/log_adapter.h"
|
||||
#include "include/infer_log.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace serving {
|
||||
|
||||
bool StartWith(const std::string &str, const std::string &expected) {
|
||||
return expected.empty() ||
|
||||
(str.size() >= expected.size() && memcmp(str.data(), expected.data(), expected.size()) == 0);
|
||||
|
|
|
@ -15,18 +15,19 @@
|
|||
*/
|
||||
#include "core/version_control/model.h"
|
||||
#include <string>
|
||||
#include "mindspore/ccsrc/utils/log_adapter.h"
|
||||
#include "include/infer_log.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace serving {
|
||||
|
||||
MindSporeModel::MindSporeModel(const std::string &model_name, const std::string &model_path,
|
||||
const std::string &model_version, const time_t &last_update_time)
|
||||
: model_name_(model_name),
|
||||
model_path_(model_path),
|
||||
model_version_(model_version),
|
||||
last_update_time_(last_update_time) {
|
||||
MS_LOG(INFO) << "init mindspore model, model_name = " << model_name_ << ", model_path = " << model_path_
|
||||
<< ", model_version = " << model_version_ << ", last_update_time = " << last_update_time_;
|
||||
MSI_LOG(INFO) << "init mindspore model, model_name = " << model_name_ << ", model_path = " << model_path_
|
||||
<< ", model_version = " << model_version_ << ", last_update_time = " << last_update_time_;
|
||||
}
|
||||
} // namespace serving
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,11 +20,12 @@
|
|||
#include <ctime>
|
||||
#include <memory>
|
||||
#include "util/file_system_operation.h"
|
||||
#include "mindspore/ccsrc/utils/log_adapter.h"
|
||||
#include "include/infer_log.h"
|
||||
#include "core/server.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace serving {
|
||||
|
||||
volatile bool stop_poll = false;
|
||||
|
||||
std::string GetVersionFromPath(const std::string &path) {
|
||||
|
@ -96,7 +97,7 @@ Status VersionController::Run() {
|
|||
|
||||
Status VersionController::CreateInitModels() {
|
||||
if (!DirOrFileExist(models_path_)) {
|
||||
MS_LOG(ERROR) << "Model Path Not Exist!" << std::endl;
|
||||
MSI_LOG(ERROR) << "Model Path Not Exist!" << std::endl;
|
||||
return FAILED;
|
||||
}
|
||||
std::vector<std::string> SubDirs = GetAllSubDirs(models_path_);
|
||||
|
@ -115,7 +116,7 @@ Status VersionController::CreateInitModels() {
|
|||
}
|
||||
}
|
||||
if (valid_models_.empty()) {
|
||||
MS_LOG(ERROR) << "There is no valid model for serving";
|
||||
MSI_LOG(ERROR) << "There is no valid model for serving";
|
||||
return FAILED;
|
||||
}
|
||||
auto ret = Session::Instance().Warmup(valid_models_.back());
|
||||
|
|
|
@ -8,33 +8,33 @@ add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
|
|||
find_package(Threads REQUIRED)
|
||||
|
||||
# This branch assumes that gRPC and all its dependencies are already installed
|
||||
# on this system, so they can be located by find_package().
|
||||
# on this system, so they can be located by find_package().
|
||||
|
||||
# Find Protobuf installation
|
||||
# Looks for protobuf-config.cmake file installed by Protobuf's cmake installation.
|
||||
set(protobuf_MODULE_COMPATIBLE TRUE)
|
||||
find_package(Protobuf CONFIG REQUIRED)
|
||||
message(STATUS "Using protobuf ${protobuf_VERSION}")
|
||||
# Find Protobuf installation
|
||||
# Looks for protobuf-config.cmake file installed by Protobuf's cmake installation.
|
||||
set(protobuf_MODULE_COMPATIBLE TRUE)
|
||||
find_package(Protobuf CONFIG REQUIRED)
|
||||
message(STATUS "Using protobuf ${protobuf_VERSION}")
|
||||
|
||||
set(_PROTOBUF_LIBPROTOBUF protobuf::libprotobuf)
|
||||
set(_REFLECTION gRPC::grpc++_reflection)
|
||||
if(CMAKE_CROSSCOMPILING)
|
||||
find_program(_PROTOBUF_PROTOC protoc)
|
||||
else()
|
||||
set(_PROTOBUF_PROTOC $<TARGET_FILE:protobuf::protoc>)
|
||||
endif()
|
||||
set(_PROTOBUF_LIBPROTOBUF protobuf::libprotobuf)
|
||||
set(_REFLECTION gRPC::grpc++_reflection)
|
||||
if (CMAKE_CROSSCOMPILING)
|
||||
find_program(_PROTOBUF_PROTOC protoc)
|
||||
else ()
|
||||
set(_PROTOBUF_PROTOC $<TARGET_FILE:protobuf::protoc>)
|
||||
endif ()
|
||||
|
||||
# Find gRPC installation
|
||||
# Looks for gRPCConfig.cmake file installed by gRPC's cmake installation.
|
||||
find_package(gRPC CONFIG REQUIRED)
|
||||
message(STATUS "Using gRPC ${gRPC_VERSION}")
|
||||
# Find gRPC installation
|
||||
# Looks for gRPCConfig.cmake file installed by gRPC's cmake installation.
|
||||
find_package(gRPC CONFIG REQUIRED)
|
||||
message(STATUS "Using gRPC ${gRPC_VERSION}")
|
||||
|
||||
set(_GRPC_GRPCPP gRPC::grpc++)
|
||||
if(CMAKE_CROSSCOMPILING)
|
||||
find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin)
|
||||
else()
|
||||
set(_GRPC_CPP_PLUGIN_EXECUTABLE $<TARGET_FILE:gRPC::grpc_cpp_plugin>)
|
||||
endif()
|
||||
set(_GRPC_GRPCPP gRPC::grpc++)
|
||||
if (CMAKE_CROSSCOMPILING)
|
||||
find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin)
|
||||
else ()
|
||||
set(_GRPC_CPP_PLUGIN_EXECUTABLE $<TARGET_FILE:gRPC::grpc_cpp_plugin>)
|
||||
endif ()
|
||||
|
||||
# Proto file
|
||||
get_filename_component(hw_proto "../ms_service.proto" ABSOLUTE)
|
||||
|
@ -59,7 +59,7 @@ add_custom_command(
|
|||
include_directories("${CMAKE_CURRENT_BINARY_DIR}")
|
||||
|
||||
# Targets greeter_[async_](client|server)
|
||||
foreach(_target
|
||||
foreach (_target
|
||||
ms_client ms_server)
|
||||
add_executable(${_target} "${_target}.cc"
|
||||
${hw_proto_srcs}
|
||||
|
@ -68,4 +68,4 @@ foreach(_target
|
|||
${_REFLECTION}
|
||||
${_GRPC_GRPCPP}
|
||||
${_PROTOBUF_LIBPROTOBUF})
|
||||
endforeach()
|
||||
endforeach ()
|
||||
|
|
|
@ -211,12 +211,77 @@ PredictRequest ReadBertInput() {
|
|||
return request;
|
||||
}
|
||||
|
||||
PredictRequest ReadLenetInput() {
|
||||
size_t size;
|
||||
auto buf = ReadFile("lenet_img.bin", &size);
|
||||
if (buf == nullptr) {
|
||||
std::cout << "read file failed" << std::endl;
|
||||
return PredictRequest();
|
||||
}
|
||||
PredictRequest request;
|
||||
auto cur = buf;
|
||||
if (size > 0) {
|
||||
Tensor data;
|
||||
TensorShape shape;
|
||||
// set type
|
||||
data.set_tensor_type(ms_serving::MS_FLOAT32);
|
||||
|
||||
// set shape
|
||||
shape.add_dims(size / sizeof(float));
|
||||
*data.mutable_tensor_shape() = shape;
|
||||
|
||||
// set data
|
||||
data.set_data(cur, size);
|
||||
*request.add_data() = data;
|
||||
}
|
||||
std::cout << "get input data size " << size << std::endl;
|
||||
return request;
|
||||
}
|
||||
|
||||
PredictRequest ReadOtherInput(const std::string &data_file) {
|
||||
size_t size;
|
||||
auto buf = ReadFile(data_file.c_str(), &size);
|
||||
if (buf == nullptr) {
|
||||
std::cout << "read file failed" << std::endl;
|
||||
return PredictRequest();
|
||||
}
|
||||
PredictRequest request;
|
||||
auto cur = buf;
|
||||
if (size > 0) {
|
||||
Tensor data;
|
||||
TensorShape shape;
|
||||
// set type
|
||||
data.set_tensor_type(ms_serving::MS_FLOAT32);
|
||||
|
||||
// set shape
|
||||
shape.add_dims(size / sizeof(float));
|
||||
*data.mutable_tensor_shape() = shape;
|
||||
|
||||
// set data
|
||||
data.set_data(cur, size);
|
||||
*request.add_data() = data;
|
||||
}
|
||||
std::cout << "get input data size " << size << std::endl;
|
||||
return request;
|
||||
}
|
||||
|
||||
template <class DT>
|
||||
void print_array_item(const DT *data, size_t size) {
|
||||
for (size_t i = 0; i < size && i < 100; i++) {
|
||||
std::cout << data[i] << '\t';
|
||||
if ((i + 1) % 10 == 0) {
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
class MSClient {
|
||||
public:
|
||||
explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {}
|
||||
~MSClient() = default;
|
||||
|
||||
std::string Predict(const std::string &type) {
|
||||
std::string Predict(const std::string &type, const std::string &data_file) {
|
||||
// Data we are sending to the server.
|
||||
PredictRequest request;
|
||||
if (type == "add") {
|
||||
|
@ -234,6 +299,10 @@ class MSClient {
|
|||
*request.add_data() = data;
|
||||
} else if (type == "bert") {
|
||||
request = ReadBertInput();
|
||||
} else if (type == "lenet") {
|
||||
request = ReadLenetInput();
|
||||
} else if (type == "other") {
|
||||
request = ReadOtherInput(data_file);
|
||||
} else {
|
||||
std::cout << "type only support bert or add, but input is " << type << std::endl;
|
||||
}
|
||||
|
@ -256,6 +325,20 @@ class MSClient {
|
|||
|
||||
// Act upon its status.
|
||||
if (status.ok()) {
|
||||
for (size_t i = 0; i < reply.result_size(); i++) {
|
||||
auto result = reply.result(i);
|
||||
if (result.tensor_type() == ms_serving::DataType::MS_FLOAT32) {
|
||||
print_array_item(reinterpret_cast<const float *>(result.data().data()), result.data().size() / sizeof(float));
|
||||
} else if (result.tensor_type() == ms_serving::DataType::MS_INT32) {
|
||||
print_array_item(reinterpret_cast<const int32_t *>(result.data().data()),
|
||||
result.data().size() / sizeof(int32_t));
|
||||
} else if (result.tensor_type() == ms_serving::DataType::MS_UINT32) {
|
||||
print_array_item(reinterpret_cast<const uint32_t *>(result.data().data()),
|
||||
result.data().size() / sizeof(uint32_t));
|
||||
} else {
|
||||
std::cout << "output datatype " << result.tensor_type() << std::endl;
|
||||
}
|
||||
}
|
||||
return "RPC OK";
|
||||
} else {
|
||||
std::cout << status.error_code() << ": " << status.error_message() << std::endl;
|
||||
|
@ -277,6 +360,8 @@ int main(int argc, char **argv) {
|
|||
std::string arg_target_str("--target");
|
||||
std::string type;
|
||||
std::string arg_type_str("--type");
|
||||
std::string arg_data_str("--data");
|
||||
std::string data = "default_data.bin";
|
||||
if (argc > 2) {
|
||||
{
|
||||
// parse target
|
||||
|
@ -304,19 +389,33 @@ int main(int argc, char **argv) {
|
|||
if (arg_val2[start_pos] == '=') {
|
||||
type = arg_val2.substr(start_pos + 1);
|
||||
} else {
|
||||
std::cout << "The only correct argument syntax is --target=" << std::endl;
|
||||
std::cout << "The only correct argument syntax is --type=" << std::endl;
|
||||
return 0;
|
||||
}
|
||||
} else {
|
||||
type = "add";
|
||||
}
|
||||
}
|
||||
if (argc > 3) {
|
||||
// parse type
|
||||
std::string arg_val3 = argv[3];
|
||||
size_t start_pos = arg_val3.find(arg_data_str);
|
||||
if (start_pos != std::string::npos) {
|
||||
start_pos += arg_data_str.size();
|
||||
if (arg_val3[start_pos] == '=') {
|
||||
data = arg_val3.substr(start_pos + 1);
|
||||
} else {
|
||||
std::cout << "The only correct argument syntax is --data=" << std::endl;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
target_str = "localhost:5500";
|
||||
type = "add";
|
||||
}
|
||||
MSClient client(grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials()));
|
||||
std::string reply = client.Predict(type);
|
||||
std::string reply = client.Predict(type, data);
|
||||
std::cout << "client received: " << reply << std::endl;
|
||||
|
||||
return 0;
|
||||
|
|
Loading…
Reference in New Issue