From c1a518ce24a3832f457399efc8660334cbe38a4a Mon Sep 17 00:00:00 2001 From: dinghao Date: Tue, 23 Jun 2020 09:45:22 +0800 Subject: [PATCH] add bert example --- mindspore/ccsrc/CMakeLists.txt | 2 +- serving/cpp_example/CMakeLists.txt | 5 +- serving/cpp_example/ms_client.cc | 282 +++++++++++++++++++++++++---- serving/cpp_example/ms_server.cc | 4 +- 4 files changed, 255 insertions(+), 38 deletions(-) diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index 80f82fd7ea8..9409a399234 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -247,7 +247,7 @@ add_library(inference SHARED ${CMAKE_CURRENT_SOURCE_DIR}/session/session.cc ${LOAD_ONNX_SRC} ) -target_link_libraries(inference PRIVATE ${PYTHON_LIBRARY} ${SECUREC_LIBRARY} +target_link_libraries(inference PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY} -Wl,--whole-archive mindspore -Wl,--no-whole-archive mindspore_gvar mindspore::protobuf) if (ENABLE_CPU) diff --git a/serving/cpp_example/CMakeLists.txt b/serving/cpp_example/CMakeLists.txt index a45bd123d61..aaf02778802 100644 --- a/serving/cpp_example/CMakeLists.txt +++ b/serving/cpp_example/CMakeLists.txt @@ -2,9 +2,8 @@ cmake_minimum_required(VERSION 3.5.1) project(HelloWorld C CXX) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") - +add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) find_package(Threads REQUIRED) @@ -69,4 +68,4 @@ foreach(_target ${_REFLECTION} ${_GRPC_GRPCPP} ${_PROTOBUF_LIBPROTOBUF}) -endforeach() \ No newline at end of file +endforeach() diff --git a/serving/cpp_example/ms_client.cc b/serving/cpp_example/ms_client.cc index 035a145714d..846dd618eba 100644 --- a/serving/cpp_example/ms_client.cc +++ b/serving/cpp_example/ms_client.cc @@ -15,8 +15,10 @@ */ #include #include - -#include "serving/ms_service.grpc.pb.h" +#include +#include +#include +#include "./ms_service.grpc.pb.h" using grpc::Channel; using grpc::ClientContext; @@ -27,26 +29,214 @@ using ms_serving::PredictRequest; using ms_serving::Tensor; using ms_serving::TensorShape; +enum TypeId : int { + kTypeUnknown = 0, + kMetaTypeBegin = kTypeUnknown, + kMetaTypeType, // Type + kMetaTypeAnything, + kMetaTypeObject, + kMetaTypeTypeType, // TypeType + kMetaTypeProblem, + kMetaTypeExternal, + kMetaTypeNone, + kMetaTypeNull, + kMetaTypeEllipsis, + kMetaTypeEnd, + // + // Object types + // + kObjectTypeBegin = kMetaTypeEnd, + kObjectTypeNumber, + kObjectTypeString, + kObjectTypeList, + kObjectTypeTuple, + kObjectTypeSlice, + kObjectTypeKeyword, + kObjectTypeTensorType, + kObjectTypeClass, + kObjectTypeDictionary, + kObjectTypeFunction, + kObjectTypeJTagged, + kObjectTypeSymbolicKeyType, + kObjectTypeEnvType, + kObjectTypeRefKey, + kObjectTypeRef, + kObjectTypeEnd, + // + // Number Types + // + kNumberTypeBegin = kObjectTypeEnd, + kNumberTypeBool, + kNumberTypeInt, + kNumberTypeInt8, + kNumberTypeInt16, + kNumberTypeInt32, + kNumberTypeInt64, + kNumberTypeUInt, + kNumberTypeUInt8, + kNumberTypeUInt16, + kNumberTypeUInt32, + kNumberTypeUInt64, + kNumberTypeFloat, + kNumberTypeFloat16, + kNumberTypeFloat32, + kNumberTypeFloat64, + kNumberTypeEnd +}; + +std::string RealPath(const char *path) { + if (path == nullptr) { + std::cout << "path is nullptr"; + return ""; + } + if ((strlen(path)) >= PATH_MAX) { + std::cout << "path is too long"; + return ""; + } + + std::shared_ptr resolvedPath(new (std::nothrow) char[PATH_MAX]{0}); + if (resolvedPath == nullptr) { + std::cout << "new resolvedPath failed"; + return ""; + } + + auto ret = realpath(path, resolvedPath.get()); + if (ret == nullptr) { + std::cout << "realpath failed"; + return ""; + } + return resolvedPath.get(); +} + +char *ReadFile(const char *file, size_t *size) { + if (file == nullptr) { + std::cout << "file is nullptr" << std::endl; + return nullptr; + } + if (size == nullptr) { + std::cout << "size should not be nullptr" << std::endl; + return nullptr; + } + std::ifstream ifs(RealPath(file)); + if (!ifs.good()) { + std::cout << "file: " << file << "is not exist"; + return nullptr; + } + + if (!ifs.is_open()) { + std::cout << "file: " << file << "open failed"; + return nullptr; + } + + ifs.seekg(0, std::ios::end); + *size = ifs.tellg(); + std::unique_ptr buf(new (std::nothrow) char[*size]); + if (buf == nullptr) { + std::cout << "malloc buf failed, file: " << file; + ifs.close(); + return nullptr; + } + + ifs.seekg(0, std::ios::beg); + ifs.read(buf.get(), *size); + ifs.close(); + + return buf.release(); +} +const std::map id2type_map{ + {TypeId::kNumberTypeBegin, ms_serving::MS_UNKNOWN}, {TypeId::kNumberTypeBool, ms_serving::MS_BOOL}, + {TypeId::kNumberTypeInt8, ms_serving::MS_INT8}, {TypeId::kNumberTypeUInt8, ms_serving::MS_UINT8}, + {TypeId::kNumberTypeInt16, ms_serving::MS_INT16}, {TypeId::kNumberTypeUInt16, ms_serving::MS_UINT16}, + {TypeId::kNumberTypeInt32, ms_serving::MS_INT32}, {TypeId::kNumberTypeUInt32, ms_serving::MS_UINT32}, + {TypeId::kNumberTypeInt64, ms_serving::MS_INT64}, {TypeId::kNumberTypeUInt64, ms_serving::MS_UINT64}, + {TypeId::kNumberTypeFloat16, ms_serving::MS_FLOAT16}, {TypeId::kNumberTypeFloat32, ms_serving::MS_FLOAT32}, + {TypeId::kNumberTypeFloat64, ms_serving::MS_FLOAT64}, +}; + +int WriteFile(const void *buf, size_t size) { + auto fd = fopen("output.json", "a+"); + if (fd == NULL) { + std::cout << "fd is null and open file fail" << std::endl; + return 0; + } + fwrite(buf, size, 1, fd); + fclose(fd); + return 0; +} + +PredictRequest ReadBertInput() { + size_t size; + auto buf = ReadFile("input206.json", &size); + if (buf == nullptr) { + std::cout << "read file failed" << std::endl; + return PredictRequest(); + } + PredictRequest request; + auto cur = buf; + while (size > 0) { + if (request.data_size() == 4) { + break; + } + Tensor data; + TensorShape shape; + // set type + int type = *(reinterpret_cast(cur)); + cur = cur + sizeof(int); + size = size - sizeof(int); + ms_serving::DataType dataType = id2type_map.at(TypeId(type)); + data.set_tensor_type(dataType); + + // set shape + size_t dims = *(reinterpret_cast(cur)); + cur = cur + sizeof(size_t); + size = size - sizeof(size_t); + + for (size_t i = 0; i < dims; i++) { + int dim = *(reinterpret_cast(cur)); + shape.add_dims(dim); + cur = cur + sizeof(int); + size = size - sizeof(int); + } + *data.mutable_tensor_shape() = shape; + + // set data + size_t data_len = *(reinterpret_cast(cur)); + cur = cur + sizeof(size_t); + size = size - sizeof(size_t); + data.set_data(cur, data_len); + cur = cur + data_len; + size = size - data_len; + *request.add_data() = data; + } + return request; +} + class MSClient { public: explicit MSClient(std::shared_ptr channel) : stub_(MSService::NewStub(channel)) {} - std::string Predict(const std::string &user) { + std::string Predict(const std::string &type) { // Data we are sending to the server. PredictRequest request; - Tensor data; - TensorShape shape; - shape.add_dims(1); - shape.add_dims(1); - shape.add_dims(2); - shape.add_dims(2); - *data.mutable_tensor_shape() = shape; - data.set_tensor_type(ms_serving::MS_FLOAT32); - vector input_data{1.1, 2.1, 3.1, 4.1}; - data.set_data(input_data.data(), input_data.size()); - *request.add_data() = data; - *request.add_data() = data; - + if (type == "add") { + Tensor data; + TensorShape shape; + shape.add_dims(1); + shape.add_dims(1); + shape.add_dims(2); + shape.add_dims(2); + *data.mutable_tensor_shape() = shape; + data.set_tensor_type(ms_serving::MS_FLOAT32); + std::vector input_data{1.1, 2.1, 3.1, 4.1}; + data.set_data(input_data.data(), input_data.size()); + *request.add_data() = data; + *request.add_data() = data; + } else if (type == "bert") { + request = ReadBertInput(); + } else { + std::cout << "type only support bert or add, but input is " << type << std::endl; + } + std::cout << "intput tensor size is " << request.data_size() << std::endl; // Container for the data we expect from the server. PredictReply reply; @@ -57,6 +247,12 @@ class MSClient { // The actual RPC. Status status = stub_->Predict(&context, request, &reply); + for (int i = 0; i < reply.result_size(); i++) { + WriteFile(reply.result(i).data().data(), reply.result(i).data().size()); + } + + std::cout << "the return result size is " << reply.result_size() << std::endl; + // Act upon its status. if (status.ok()) { return "RPC OK"; @@ -77,28 +273,50 @@ int main(int argc, char **argv) { // We indicate that the channel isn't authenticated (use of // InsecureChannelCredentials()). std::string target_str; - std::string arg_str("--target"); - if (argc > 1) { - std::string arg_val = argv[1]; - size_t start_pos = arg_val.find(arg_str); - if (start_pos != std::string::npos) { - start_pos += arg_str.size(); - if (arg_val[start_pos] == '=') { - target_str = arg_val.substr(start_pos + 1); + std::string arg_target_str("--target"); + std::string type; + std::string arg_type_str("--type"); + if (argc > 2) { + { + // parse target + std::string arg_val = argv[1]; + size_t start_pos = arg_val.find(arg_target_str); + if (start_pos != std::string::npos) { + start_pos += arg_target_str.size(); + if (arg_val[start_pos] == '=') { + target_str = arg_val.substr(start_pos + 1); + } else { + std::cout << "The only correct argument syntax is --target=" << std::endl; + return 0; + } } else { - std::cout << "The only correct argument syntax is --target=" << std::endl; - return 0; + target_str = "localhost:5500"; } - } else { - std::cout << "The only acceptable argument is --target=" << std::endl; - return 0; } + + { + // parse type + std::string arg_val2 = argv[2]; + size_t start_pos = arg_val2.find(arg_type_str); + if (start_pos != std::string::npos) { + start_pos += arg_type_str.size(); + if (arg_val2[start_pos] == '=') { + type = arg_val2.substr(start_pos + 1); + } else { + std::cout << "The only correct argument syntax is --target=" << std::endl; + return 0; + } + } else { + type = "add"; + } + } + } else { - target_str = "localhost:85010"; + target_str = "localhost:5500"; + type = "add"; } MSClient client(grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials())); - string request; - string reply = client.Predict(request); + std::string reply = client.Predict(type); std::cout << "client received: " << reply << std::endl; return 0; diff --git a/serving/cpp_example/ms_server.cc b/serving/cpp_example/ms_server.cc index 94ba949832d..f6021ef0009 100644 --- a/serving/cpp_example/ms_server.cc +++ b/serving/cpp_example/ms_server.cc @@ -18,7 +18,7 @@ #include #include -#include "serving/ms_service.grpc.pb.h" +#include "./ms_service.grpc.pb.h" using grpc::Server; using grpc::ServerBuilder; @@ -31,7 +31,7 @@ using ms_serving::PredictRequest; // Logic and data behind the server's behavior. class MSServiceImpl final : public MSService::Service { Status Predict(ServerContext *context, const PredictRequest *request, PredictReply *reply) override { - cout << "server eval" << endl; + std::cout << "server eval" << std::endl; return Status::OK; } };