add bert example

This commit is contained in:
dinghao 2020-06-23 09:45:22 +08:00
parent 35ab95bfae
commit 4ddb00b996
4 changed files with 255 additions and 38 deletions

View File

@ -247,7 +247,7 @@ add_library(inference SHARED
${CMAKE_CURRENT_SOURCE_DIR}/session/session.cc ${CMAKE_CURRENT_SOURCE_DIR}/session/session.cc
${LOAD_ONNX_SRC} ${LOAD_ONNX_SRC}
) )
target_link_libraries(inference PRIVATE ${PYTHON_LIBRARY} ${SECUREC_LIBRARY} target_link_libraries(inference PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
-Wl,--whole-archive mindspore -Wl,--no-whole-archive mindspore_gvar mindspore::protobuf) -Wl,--whole-archive mindspore -Wl,--no-whole-archive mindspore_gvar mindspore::protobuf)
if (ENABLE_CPU) if (ENABLE_CPU)

View File

@ -2,9 +2,8 @@ cmake_minimum_required(VERSION 3.5.1)
project(HelloWorld C CXX) project(HelloWorld C CXX)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
find_package(Threads REQUIRED) find_package(Threads REQUIRED)
@ -69,4 +68,4 @@ foreach(_target
${_REFLECTION} ${_REFLECTION}
${_GRPC_GRPCPP} ${_GRPC_GRPCPP}
${_PROTOBUF_LIBPROTOBUF}) ${_PROTOBUF_LIBPROTOBUF})
endforeach() endforeach()

View File

@ -15,8 +15,10 @@
*/ */
#include <grpcpp/grpcpp.h> #include <grpcpp/grpcpp.h>
#include <iostream> #include <iostream>
#include <vector>
#include "serving/ms_service.grpc.pb.h" #include <string>
#include <fstream>
#include "./ms_service.grpc.pb.h"
using grpc::Channel; using grpc::Channel;
using grpc::ClientContext; using grpc::ClientContext;
@ -27,26 +29,214 @@ using ms_serving::PredictRequest;
using ms_serving::Tensor; using ms_serving::Tensor;
using ms_serving::TensorShape; using ms_serving::TensorShape;
enum TypeId : int {
kTypeUnknown = 0,
kMetaTypeBegin = kTypeUnknown,
kMetaTypeType, // Type
kMetaTypeAnything,
kMetaTypeObject,
kMetaTypeTypeType, // TypeType
kMetaTypeProblem,
kMetaTypeExternal,
kMetaTypeNone,
kMetaTypeNull,
kMetaTypeEllipsis,
kMetaTypeEnd,
//
// Object types
//
kObjectTypeBegin = kMetaTypeEnd,
kObjectTypeNumber,
kObjectTypeString,
kObjectTypeList,
kObjectTypeTuple,
kObjectTypeSlice,
kObjectTypeKeyword,
kObjectTypeTensorType,
kObjectTypeClass,
kObjectTypeDictionary,
kObjectTypeFunction,
kObjectTypeJTagged,
kObjectTypeSymbolicKeyType,
kObjectTypeEnvType,
kObjectTypeRefKey,
kObjectTypeRef,
kObjectTypeEnd,
//
// Number Types
//
kNumberTypeBegin = kObjectTypeEnd,
kNumberTypeBool,
kNumberTypeInt,
kNumberTypeInt8,
kNumberTypeInt16,
kNumberTypeInt32,
kNumberTypeInt64,
kNumberTypeUInt,
kNumberTypeUInt8,
kNumberTypeUInt16,
kNumberTypeUInt32,
kNumberTypeUInt64,
kNumberTypeFloat,
kNumberTypeFloat16,
kNumberTypeFloat32,
kNumberTypeFloat64,
kNumberTypeEnd
};
std::string RealPath(const char *path) {
if (path == nullptr) {
std::cout << "path is nullptr";
return "";
}
if ((strlen(path)) >= PATH_MAX) {
std::cout << "path is too long";
return "";
}
std::shared_ptr<char> resolvedPath(new (std::nothrow) char[PATH_MAX]{0});
if (resolvedPath == nullptr) {
std::cout << "new resolvedPath failed";
return "";
}
auto ret = realpath(path, resolvedPath.get());
if (ret == nullptr) {
std::cout << "realpath failed";
return "";
}
return resolvedPath.get();
}
char *ReadFile(const char *file, size_t *size) {
if (file == nullptr) {
std::cout << "file is nullptr" << std::endl;
return nullptr;
}
if (size == nullptr) {
std::cout << "size should not be nullptr" << std::endl;
return nullptr;
}
std::ifstream ifs(RealPath(file));
if (!ifs.good()) {
std::cout << "file: " << file << "is not exist";
return nullptr;
}
if (!ifs.is_open()) {
std::cout << "file: " << file << "open failed";
return nullptr;
}
ifs.seekg(0, std::ios::end);
*size = ifs.tellg();
std::unique_ptr<char> buf(new (std::nothrow) char[*size]);
if (buf == nullptr) {
std::cout << "malloc buf failed, file: " << file;
ifs.close();
return nullptr;
}
ifs.seekg(0, std::ios::beg);
ifs.read(buf.get(), *size);
ifs.close();
return buf.release();
}
const std::map<TypeId, ms_serving::DataType> id2type_map{
{TypeId::kNumberTypeBegin, ms_serving::MS_UNKNOWN}, {TypeId::kNumberTypeBool, ms_serving::MS_BOOL},
{TypeId::kNumberTypeInt8, ms_serving::MS_INT8}, {TypeId::kNumberTypeUInt8, ms_serving::MS_UINT8},
{TypeId::kNumberTypeInt16, ms_serving::MS_INT16}, {TypeId::kNumberTypeUInt16, ms_serving::MS_UINT16},
{TypeId::kNumberTypeInt32, ms_serving::MS_INT32}, {TypeId::kNumberTypeUInt32, ms_serving::MS_UINT32},
{TypeId::kNumberTypeInt64, ms_serving::MS_INT64}, {TypeId::kNumberTypeUInt64, ms_serving::MS_UINT64},
{TypeId::kNumberTypeFloat16, ms_serving::MS_FLOAT16}, {TypeId::kNumberTypeFloat32, ms_serving::MS_FLOAT32},
{TypeId::kNumberTypeFloat64, ms_serving::MS_FLOAT64},
};
int WriteFile(const void *buf, size_t size) {
auto fd = fopen("output.json", "a+");
if (fd == NULL) {
std::cout << "fd is null and open file fail" << std::endl;
return 0;
}
fwrite(buf, size, 1, fd);
fclose(fd);
return 0;
}
PredictRequest ReadBertInput() {
size_t size;
auto buf = ReadFile("input206.json", &size);
if (buf == nullptr) {
std::cout << "read file failed" << std::endl;
return PredictRequest();
}
PredictRequest request;
auto cur = buf;
while (size > 0) {
if (request.data_size() == 4) {
break;
}
Tensor data;
TensorShape shape;
// set type
int type = *(reinterpret_cast<int *>(cur));
cur = cur + sizeof(int);
size = size - sizeof(int);
ms_serving::DataType dataType = id2type_map.at(TypeId(type));
data.set_tensor_type(dataType);
// set shape
size_t dims = *(reinterpret_cast<size_t *>(cur));
cur = cur + sizeof(size_t);
size = size - sizeof(size_t);
for (size_t i = 0; i < dims; i++) {
int dim = *(reinterpret_cast<int *>(cur));
shape.add_dims(dim);
cur = cur + sizeof(int);
size = size - sizeof(int);
}
*data.mutable_tensor_shape() = shape;
// set data
size_t data_len = *(reinterpret_cast<size_t *>(cur));
cur = cur + sizeof(size_t);
size = size - sizeof(size_t);
data.set_data(cur, data_len);
cur = cur + data_len;
size = size - data_len;
*request.add_data() = data;
}
return request;
}
class MSClient { class MSClient {
public: public:
explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {} explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {}
std::string Predict(const std::string &user) { std::string Predict(const std::string &type) {
// Data we are sending to the server. // Data we are sending to the server.
PredictRequest request; PredictRequest request;
Tensor data; if (type == "add") {
TensorShape shape; Tensor data;
shape.add_dims(1); TensorShape shape;
shape.add_dims(1); shape.add_dims(1);
shape.add_dims(2); shape.add_dims(1);
shape.add_dims(2); shape.add_dims(2);
*data.mutable_tensor_shape() = shape; shape.add_dims(2);
data.set_tensor_type(ms_serving::MS_FLOAT32); *data.mutable_tensor_shape() = shape;
vector<float> input_data{1.1, 2.1, 3.1, 4.1}; data.set_tensor_type(ms_serving::MS_FLOAT32);
data.set_data(input_data.data(), input_data.size()); std::vector<float> input_data{1.1, 2.1, 3.1, 4.1};
*request.add_data() = data; data.set_data(input_data.data(), input_data.size());
*request.add_data() = data; *request.add_data() = data;
*request.add_data() = data;
} else if (type == "bert") {
request = ReadBertInput();
} else {
std::cout << "type only support bert or add, but input is " << type << std::endl;
}
std::cout << "intput tensor size is " << request.data_size() << std::endl;
// Container for the data we expect from the server. // Container for the data we expect from the server.
PredictReply reply; PredictReply reply;
@ -57,6 +247,12 @@ class MSClient {
// The actual RPC. // The actual RPC.
Status status = stub_->Predict(&context, request, &reply); Status status = stub_->Predict(&context, request, &reply);
for (int i = 0; i < reply.result_size(); i++) {
WriteFile(reply.result(i).data().data(), reply.result(i).data().size());
}
std::cout << "the return result size is " << reply.result_size() << std::endl;
// Act upon its status. // Act upon its status.
if (status.ok()) { if (status.ok()) {
return "RPC OK"; return "RPC OK";
@ -77,28 +273,50 @@ int main(int argc, char **argv) {
// We indicate that the channel isn't authenticated (use of // We indicate that the channel isn't authenticated (use of
// InsecureChannelCredentials()). // InsecureChannelCredentials()).
std::string target_str; std::string target_str;
std::string arg_str("--target"); std::string arg_target_str("--target");
if (argc > 1) { std::string type;
std::string arg_val = argv[1]; std::string arg_type_str("--type");
size_t start_pos = arg_val.find(arg_str); if (argc > 2) {
if (start_pos != std::string::npos) { {
start_pos += arg_str.size(); // parse target
if (arg_val[start_pos] == '=') { std::string arg_val = argv[1];
target_str = arg_val.substr(start_pos + 1); size_t start_pos = arg_val.find(arg_target_str);
if (start_pos != std::string::npos) {
start_pos += arg_target_str.size();
if (arg_val[start_pos] == '=') {
target_str = arg_val.substr(start_pos + 1);
} else {
std::cout << "The only correct argument syntax is --target=" << std::endl;
return 0;
}
} else { } else {
std::cout << "The only correct argument syntax is --target=" << std::endl; target_str = "localhost:5500";
return 0;
} }
} else {
std::cout << "The only acceptable argument is --target=" << std::endl;
return 0;
} }
{
// parse type
std::string arg_val2 = argv[2];
size_t start_pos = arg_val2.find(arg_type_str);
if (start_pos != std::string::npos) {
start_pos += arg_type_str.size();
if (arg_val2[start_pos] == '=') {
type = arg_val2.substr(start_pos + 1);
} else {
std::cout << "The only correct argument syntax is --target=" << std::endl;
return 0;
}
} else {
type = "add";
}
}
} else { } else {
target_str = "localhost:85010"; target_str = "localhost:5500";
type = "add";
} }
MSClient client(grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials())); MSClient client(grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials()));
string request; std::string reply = client.Predict(type);
string reply = client.Predict(request);
std::cout << "client received: " << reply << std::endl; std::cout << "client received: " << reply << std::endl;
return 0; return 0;

View File

@ -18,7 +18,7 @@
#include <grpcpp/ext/proto_server_reflection_plugin.h> #include <grpcpp/ext/proto_server_reflection_plugin.h>
#include <iostream> #include <iostream>
#include "serving/ms_service.grpc.pb.h" #include "./ms_service.grpc.pb.h"
using grpc::Server; using grpc::Server;
using grpc::ServerBuilder; using grpc::ServerBuilder;
@ -31,7 +31,7 @@ using ms_serving::PredictRequest;
// Logic and data behind the server's behavior. // Logic and data behind the server's behavior.
class MSServiceImpl final : public MSService::Service { class MSServiceImpl final : public MSService::Service {
Status Predict(ServerContext *context, const PredictRequest *request, PredictReply *reply) override { Status Predict(ServerContext *context, const PredictRequest *request, PredictReply *reply) override {
cout << "server eval" << endl; std::cout << "server eval" << std::endl;
return Status::OK; return Status::OK;
} }
}; };