forked from mindspore-Ecosystem/mindspore
add bert example
This commit is contained in:
parent
35ab95bfae
commit
4ddb00b996
|
@ -247,7 +247,7 @@ add_library(inference SHARED
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/session/session.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/session/session.cc
|
||||||
${LOAD_ONNX_SRC}
|
${LOAD_ONNX_SRC}
|
||||||
)
|
)
|
||||||
target_link_libraries(inference PRIVATE ${PYTHON_LIBRARY} ${SECUREC_LIBRARY}
|
target_link_libraries(inference PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
|
||||||
-Wl,--whole-archive mindspore -Wl,--no-whole-archive mindspore_gvar mindspore::protobuf)
|
-Wl,--whole-archive mindspore -Wl,--no-whole-archive mindspore_gvar mindspore::protobuf)
|
||||||
|
|
||||||
if (ENABLE_CPU)
|
if (ENABLE_CPU)
|
||||||
|
|
|
@ -2,9 +2,8 @@ cmake_minimum_required(VERSION 3.5.1)
|
||||||
|
|
||||||
project(HelloWorld C CXX)
|
project(HelloWorld C CXX)
|
||||||
|
|
||||||
|
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
|
||||||
|
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
|
||||||
|
|
||||||
find_package(Threads REQUIRED)
|
find_package(Threads REQUIRED)
|
||||||
|
|
||||||
|
@ -69,4 +68,4 @@ foreach(_target
|
||||||
${_REFLECTION}
|
${_REFLECTION}
|
||||||
${_GRPC_GRPCPP}
|
${_GRPC_GRPCPP}
|
||||||
${_PROTOBUF_LIBPROTOBUF})
|
${_PROTOBUF_LIBPROTOBUF})
|
||||||
endforeach()
|
endforeach()
|
||||||
|
|
|
@ -15,8 +15,10 @@
|
||||||
*/
|
*/
|
||||||
#include <grpcpp/grpcpp.h>
|
#include <grpcpp/grpcpp.h>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
#include <vector>
|
||||||
#include "serving/ms_service.grpc.pb.h"
|
#include <string>
|
||||||
|
#include <fstream>
|
||||||
|
#include "./ms_service.grpc.pb.h"
|
||||||
|
|
||||||
using grpc::Channel;
|
using grpc::Channel;
|
||||||
using grpc::ClientContext;
|
using grpc::ClientContext;
|
||||||
|
@ -27,26 +29,214 @@ using ms_serving::PredictRequest;
|
||||||
using ms_serving::Tensor;
|
using ms_serving::Tensor;
|
||||||
using ms_serving::TensorShape;
|
using ms_serving::TensorShape;
|
||||||
|
|
||||||
|
enum TypeId : int {
|
||||||
|
kTypeUnknown = 0,
|
||||||
|
kMetaTypeBegin = kTypeUnknown,
|
||||||
|
kMetaTypeType, // Type
|
||||||
|
kMetaTypeAnything,
|
||||||
|
kMetaTypeObject,
|
||||||
|
kMetaTypeTypeType, // TypeType
|
||||||
|
kMetaTypeProblem,
|
||||||
|
kMetaTypeExternal,
|
||||||
|
kMetaTypeNone,
|
||||||
|
kMetaTypeNull,
|
||||||
|
kMetaTypeEllipsis,
|
||||||
|
kMetaTypeEnd,
|
||||||
|
//
|
||||||
|
// Object types
|
||||||
|
//
|
||||||
|
kObjectTypeBegin = kMetaTypeEnd,
|
||||||
|
kObjectTypeNumber,
|
||||||
|
kObjectTypeString,
|
||||||
|
kObjectTypeList,
|
||||||
|
kObjectTypeTuple,
|
||||||
|
kObjectTypeSlice,
|
||||||
|
kObjectTypeKeyword,
|
||||||
|
kObjectTypeTensorType,
|
||||||
|
kObjectTypeClass,
|
||||||
|
kObjectTypeDictionary,
|
||||||
|
kObjectTypeFunction,
|
||||||
|
kObjectTypeJTagged,
|
||||||
|
kObjectTypeSymbolicKeyType,
|
||||||
|
kObjectTypeEnvType,
|
||||||
|
kObjectTypeRefKey,
|
||||||
|
kObjectTypeRef,
|
||||||
|
kObjectTypeEnd,
|
||||||
|
//
|
||||||
|
// Number Types
|
||||||
|
//
|
||||||
|
kNumberTypeBegin = kObjectTypeEnd,
|
||||||
|
kNumberTypeBool,
|
||||||
|
kNumberTypeInt,
|
||||||
|
kNumberTypeInt8,
|
||||||
|
kNumberTypeInt16,
|
||||||
|
kNumberTypeInt32,
|
||||||
|
kNumberTypeInt64,
|
||||||
|
kNumberTypeUInt,
|
||||||
|
kNumberTypeUInt8,
|
||||||
|
kNumberTypeUInt16,
|
||||||
|
kNumberTypeUInt32,
|
||||||
|
kNumberTypeUInt64,
|
||||||
|
kNumberTypeFloat,
|
||||||
|
kNumberTypeFloat16,
|
||||||
|
kNumberTypeFloat32,
|
||||||
|
kNumberTypeFloat64,
|
||||||
|
kNumberTypeEnd
|
||||||
|
};
|
||||||
|
|
||||||
|
std::string RealPath(const char *path) {
|
||||||
|
if (path == nullptr) {
|
||||||
|
std::cout << "path is nullptr";
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
if ((strlen(path)) >= PATH_MAX) {
|
||||||
|
std::cout << "path is too long";
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<char> resolvedPath(new (std::nothrow) char[PATH_MAX]{0});
|
||||||
|
if (resolvedPath == nullptr) {
|
||||||
|
std::cout << "new resolvedPath failed";
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ret = realpath(path, resolvedPath.get());
|
||||||
|
if (ret == nullptr) {
|
||||||
|
std::cout << "realpath failed";
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
return resolvedPath.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
char *ReadFile(const char *file, size_t *size) {
|
||||||
|
if (file == nullptr) {
|
||||||
|
std::cout << "file is nullptr" << std::endl;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (size == nullptr) {
|
||||||
|
std::cout << "size should not be nullptr" << std::endl;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
std::ifstream ifs(RealPath(file));
|
||||||
|
if (!ifs.good()) {
|
||||||
|
std::cout << "file: " << file << "is not exist";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!ifs.is_open()) {
|
||||||
|
std::cout << "file: " << file << "open failed";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
ifs.seekg(0, std::ios::end);
|
||||||
|
*size = ifs.tellg();
|
||||||
|
std::unique_ptr<char> buf(new (std::nothrow) char[*size]);
|
||||||
|
if (buf == nullptr) {
|
||||||
|
std::cout << "malloc buf failed, file: " << file;
|
||||||
|
ifs.close();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
ifs.seekg(0, std::ios::beg);
|
||||||
|
ifs.read(buf.get(), *size);
|
||||||
|
ifs.close();
|
||||||
|
|
||||||
|
return buf.release();
|
||||||
|
}
|
||||||
|
const std::map<TypeId, ms_serving::DataType> id2type_map{
|
||||||
|
{TypeId::kNumberTypeBegin, ms_serving::MS_UNKNOWN}, {TypeId::kNumberTypeBool, ms_serving::MS_BOOL},
|
||||||
|
{TypeId::kNumberTypeInt8, ms_serving::MS_INT8}, {TypeId::kNumberTypeUInt8, ms_serving::MS_UINT8},
|
||||||
|
{TypeId::kNumberTypeInt16, ms_serving::MS_INT16}, {TypeId::kNumberTypeUInt16, ms_serving::MS_UINT16},
|
||||||
|
{TypeId::kNumberTypeInt32, ms_serving::MS_INT32}, {TypeId::kNumberTypeUInt32, ms_serving::MS_UINT32},
|
||||||
|
{TypeId::kNumberTypeInt64, ms_serving::MS_INT64}, {TypeId::kNumberTypeUInt64, ms_serving::MS_UINT64},
|
||||||
|
{TypeId::kNumberTypeFloat16, ms_serving::MS_FLOAT16}, {TypeId::kNumberTypeFloat32, ms_serving::MS_FLOAT32},
|
||||||
|
{TypeId::kNumberTypeFloat64, ms_serving::MS_FLOAT64},
|
||||||
|
};
|
||||||
|
|
||||||
|
int WriteFile(const void *buf, size_t size) {
|
||||||
|
auto fd = fopen("output.json", "a+");
|
||||||
|
if (fd == NULL) {
|
||||||
|
std::cout << "fd is null and open file fail" << std::endl;
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
fwrite(buf, size, 1, fd);
|
||||||
|
fclose(fd);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
PredictRequest ReadBertInput() {
|
||||||
|
size_t size;
|
||||||
|
auto buf = ReadFile("input206.json", &size);
|
||||||
|
if (buf == nullptr) {
|
||||||
|
std::cout << "read file failed" << std::endl;
|
||||||
|
return PredictRequest();
|
||||||
|
}
|
||||||
|
PredictRequest request;
|
||||||
|
auto cur = buf;
|
||||||
|
while (size > 0) {
|
||||||
|
if (request.data_size() == 4) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
Tensor data;
|
||||||
|
TensorShape shape;
|
||||||
|
// set type
|
||||||
|
int type = *(reinterpret_cast<int *>(cur));
|
||||||
|
cur = cur + sizeof(int);
|
||||||
|
size = size - sizeof(int);
|
||||||
|
ms_serving::DataType dataType = id2type_map.at(TypeId(type));
|
||||||
|
data.set_tensor_type(dataType);
|
||||||
|
|
||||||
|
// set shape
|
||||||
|
size_t dims = *(reinterpret_cast<size_t *>(cur));
|
||||||
|
cur = cur + sizeof(size_t);
|
||||||
|
size = size - sizeof(size_t);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < dims; i++) {
|
||||||
|
int dim = *(reinterpret_cast<int *>(cur));
|
||||||
|
shape.add_dims(dim);
|
||||||
|
cur = cur + sizeof(int);
|
||||||
|
size = size - sizeof(int);
|
||||||
|
}
|
||||||
|
*data.mutable_tensor_shape() = shape;
|
||||||
|
|
||||||
|
// set data
|
||||||
|
size_t data_len = *(reinterpret_cast<size_t *>(cur));
|
||||||
|
cur = cur + sizeof(size_t);
|
||||||
|
size = size - sizeof(size_t);
|
||||||
|
data.set_data(cur, data_len);
|
||||||
|
cur = cur + data_len;
|
||||||
|
size = size - data_len;
|
||||||
|
*request.add_data() = data;
|
||||||
|
}
|
||||||
|
return request;
|
||||||
|
}
|
||||||
|
|
||||||
class MSClient {
|
class MSClient {
|
||||||
public:
|
public:
|
||||||
explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {}
|
explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {}
|
||||||
|
|
||||||
std::string Predict(const std::string &user) {
|
std::string Predict(const std::string &type) {
|
||||||
// Data we are sending to the server.
|
// Data we are sending to the server.
|
||||||
PredictRequest request;
|
PredictRequest request;
|
||||||
Tensor data;
|
if (type == "add") {
|
||||||
TensorShape shape;
|
Tensor data;
|
||||||
shape.add_dims(1);
|
TensorShape shape;
|
||||||
shape.add_dims(1);
|
shape.add_dims(1);
|
||||||
shape.add_dims(2);
|
shape.add_dims(1);
|
||||||
shape.add_dims(2);
|
shape.add_dims(2);
|
||||||
*data.mutable_tensor_shape() = shape;
|
shape.add_dims(2);
|
||||||
data.set_tensor_type(ms_serving::MS_FLOAT32);
|
*data.mutable_tensor_shape() = shape;
|
||||||
vector<float> input_data{1.1, 2.1, 3.1, 4.1};
|
data.set_tensor_type(ms_serving::MS_FLOAT32);
|
||||||
data.set_data(input_data.data(), input_data.size());
|
std::vector<float> input_data{1.1, 2.1, 3.1, 4.1};
|
||||||
*request.add_data() = data;
|
data.set_data(input_data.data(), input_data.size());
|
||||||
*request.add_data() = data;
|
*request.add_data() = data;
|
||||||
|
*request.add_data() = data;
|
||||||
|
} else if (type == "bert") {
|
||||||
|
request = ReadBertInput();
|
||||||
|
} else {
|
||||||
|
std::cout << "type only support bert or add, but input is " << type << std::endl;
|
||||||
|
}
|
||||||
|
std::cout << "intput tensor size is " << request.data_size() << std::endl;
|
||||||
// Container for the data we expect from the server.
|
// Container for the data we expect from the server.
|
||||||
PredictReply reply;
|
PredictReply reply;
|
||||||
|
|
||||||
|
@ -57,6 +247,12 @@ class MSClient {
|
||||||
// The actual RPC.
|
// The actual RPC.
|
||||||
Status status = stub_->Predict(&context, request, &reply);
|
Status status = stub_->Predict(&context, request, &reply);
|
||||||
|
|
||||||
|
for (int i = 0; i < reply.result_size(); i++) {
|
||||||
|
WriteFile(reply.result(i).data().data(), reply.result(i).data().size());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::cout << "the return result size is " << reply.result_size() << std::endl;
|
||||||
|
|
||||||
// Act upon its status.
|
// Act upon its status.
|
||||||
if (status.ok()) {
|
if (status.ok()) {
|
||||||
return "RPC OK";
|
return "RPC OK";
|
||||||
|
@ -77,28 +273,50 @@ int main(int argc, char **argv) {
|
||||||
// We indicate that the channel isn't authenticated (use of
|
// We indicate that the channel isn't authenticated (use of
|
||||||
// InsecureChannelCredentials()).
|
// InsecureChannelCredentials()).
|
||||||
std::string target_str;
|
std::string target_str;
|
||||||
std::string arg_str("--target");
|
std::string arg_target_str("--target");
|
||||||
if (argc > 1) {
|
std::string type;
|
||||||
std::string arg_val = argv[1];
|
std::string arg_type_str("--type");
|
||||||
size_t start_pos = arg_val.find(arg_str);
|
if (argc > 2) {
|
||||||
if (start_pos != std::string::npos) {
|
{
|
||||||
start_pos += arg_str.size();
|
// parse target
|
||||||
if (arg_val[start_pos] == '=') {
|
std::string arg_val = argv[1];
|
||||||
target_str = arg_val.substr(start_pos + 1);
|
size_t start_pos = arg_val.find(arg_target_str);
|
||||||
|
if (start_pos != std::string::npos) {
|
||||||
|
start_pos += arg_target_str.size();
|
||||||
|
if (arg_val[start_pos] == '=') {
|
||||||
|
target_str = arg_val.substr(start_pos + 1);
|
||||||
|
} else {
|
||||||
|
std::cout << "The only correct argument syntax is --target=" << std::endl;
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
std::cout << "The only correct argument syntax is --target=" << std::endl;
|
target_str = "localhost:5500";
|
||||||
return 0;
|
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
std::cout << "The only acceptable argument is --target=" << std::endl;
|
|
||||||
return 0;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// parse type
|
||||||
|
std::string arg_val2 = argv[2];
|
||||||
|
size_t start_pos = arg_val2.find(arg_type_str);
|
||||||
|
if (start_pos != std::string::npos) {
|
||||||
|
start_pos += arg_type_str.size();
|
||||||
|
if (arg_val2[start_pos] == '=') {
|
||||||
|
type = arg_val2.substr(start_pos + 1);
|
||||||
|
} else {
|
||||||
|
std::cout << "The only correct argument syntax is --target=" << std::endl;
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
type = "add";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
target_str = "localhost:85010";
|
target_str = "localhost:5500";
|
||||||
|
type = "add";
|
||||||
}
|
}
|
||||||
MSClient client(grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials()));
|
MSClient client(grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials()));
|
||||||
string request;
|
std::string reply = client.Predict(type);
|
||||||
string reply = client.Predict(request);
|
|
||||||
std::cout << "client received: " << reply << std::endl;
|
std::cout << "client received: " << reply << std::endl;
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
#include <grpcpp/ext/proto_server_reflection_plugin.h>
|
#include <grpcpp/ext/proto_server_reflection_plugin.h>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
||||||
#include "serving/ms_service.grpc.pb.h"
|
#include "./ms_service.grpc.pb.h"
|
||||||
|
|
||||||
using grpc::Server;
|
using grpc::Server;
|
||||||
using grpc::ServerBuilder;
|
using grpc::ServerBuilder;
|
||||||
|
@ -31,7 +31,7 @@ using ms_serving::PredictRequest;
|
||||||
// Logic and data behind the server's behavior.
|
// Logic and data behind the server's behavior.
|
||||||
class MSServiceImpl final : public MSService::Service {
|
class MSServiceImpl final : public MSService::Service {
|
||||||
Status Predict(ServerContext *context, const PredictRequest *request, PredictReply *reply) override {
|
Status Predict(ServerContext *context, const PredictRequest *request, PredictReply *reply) override {
|
||||||
cout << "server eval" << endl;
|
std::cout << "server eval" << std::endl;
|
||||||
return Status::OK;
|
return Status::OK;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
Loading…
Reference in New Issue