forked from mindspore-Ecosystem/mindspore
!3323 restructure client example
Merge pull request !3323 from hexia/restructure_client_example
This commit is contained in:
commit
49053e7f83
|
@ -106,6 +106,7 @@ endif() # NOT ENABLE_ACL
|
|||
|
||||
if (ENABLE_SERVING)
|
||||
add_subdirectory(serving)
|
||||
add_subdirectory(serving/example/cpp_client)
|
||||
endif()
|
||||
|
||||
if (NOT ENABLE_ACL)
|
||||
|
|
|
@ -1,318 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: ms_service.proto
|
||||
|
||||
from google.protobuf.internal import enum_type_wrapper
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import message as _message
|
||||
from google.protobuf import reflection as _reflection
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor.FileDescriptor(
|
||||
name='ms_service.proto',
|
||||
package='ms_serving',
|
||||
syntax='proto3',
|
||||
serialized_options=None,
|
||||
serialized_pb=b'\n\x10ms_service.proto\x12\nms_serving\"2\n\x0ePredictRequest\x12 \n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x12.ms_serving.Tensor\"2\n\x0cPredictReply\x12\"\n\x06result\x18\x01 \x03(\x0b\x32\x12.ms_serving.Tensor\"\x1b\n\x0bTensorShape\x12\x0c\n\x04\x64ims\x18\x01 \x03(\x03\"p\n\x06Tensor\x12-\n\x0ctensor_shape\x18\x01 \x01(\x0b\x32\x17.ms_serving.TensorShape\x12)\n\x0btensor_type\x18\x02 \x01(\x0e\x32\x14.ms_serving.DataType\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c*\xc9\x01\n\x08\x44\x61taType\x12\x0e\n\nMS_UNKNOWN\x10\x00\x12\x0b\n\x07MS_BOOL\x10\x01\x12\x0b\n\x07MS_INT8\x10\x02\x12\x0c\n\x08MS_UINT8\x10\x03\x12\x0c\n\x08MS_INT16\x10\x04\x12\r\n\tMS_UINT16\x10\x05\x12\x0c\n\x08MS_INT32\x10\x06\x12\r\n\tMS_UINT32\x10\x07\x12\x0c\n\x08MS_INT64\x10\x08\x12\r\n\tMS_UINT64\x10\t\x12\x0e\n\nMS_FLOAT16\x10\n\x12\x0e\n\nMS_FLOAT32\x10\x0b\x12\x0e\n\nMS_FLOAT64\x10\x0c\x32\x8e\x01\n\tMSService\x12\x41\n\x07Predict\x12\x1a.ms_serving.PredictRequest\x1a\x18.ms_serving.PredictReply\"\x00\x12>\n\x04Test\x12\x1a.ms_serving.PredictRequest\x1a\x18.ms_serving.PredictReply\"\x00\x62\x06proto3'
|
||||
)
|
||||
|
||||
_DATATYPE = _descriptor.EnumDescriptor(
|
||||
name='DataType',
|
||||
full_name='ms_serving.DataType',
|
||||
filename=None,
|
||||
file=DESCRIPTOR,
|
||||
values=[
|
||||
_descriptor.EnumValueDescriptor(
|
||||
name='MS_UNKNOWN', index=0, number=0,
|
||||
serialized_options=None,
|
||||
type=None),
|
||||
_descriptor.EnumValueDescriptor(
|
||||
name='MS_BOOL', index=1, number=1,
|
||||
serialized_options=None,
|
||||
type=None),
|
||||
_descriptor.EnumValueDescriptor(
|
||||
name='MS_INT8', index=2, number=2,
|
||||
serialized_options=None,
|
||||
type=None),
|
||||
_descriptor.EnumValueDescriptor(
|
||||
name='MS_UINT8', index=3, number=3,
|
||||
serialized_options=None,
|
||||
type=None),
|
||||
_descriptor.EnumValueDescriptor(
|
||||
name='MS_INT16', index=4, number=4,
|
||||
serialized_options=None,
|
||||
type=None),
|
||||
_descriptor.EnumValueDescriptor(
|
||||
name='MS_UINT16', index=5, number=5,
|
||||
serialized_options=None,
|
||||
type=None),
|
||||
_descriptor.EnumValueDescriptor(
|
||||
name='MS_INT32', index=6, number=6,
|
||||
serialized_options=None,
|
||||
type=None),
|
||||
_descriptor.EnumValueDescriptor(
|
||||
name='MS_UINT32', index=7, number=7,
|
||||
serialized_options=None,
|
||||
type=None),
|
||||
_descriptor.EnumValueDescriptor(
|
||||
name='MS_INT64', index=8, number=8,
|
||||
serialized_options=None,
|
||||
type=None),
|
||||
_descriptor.EnumValueDescriptor(
|
||||
name='MS_UINT64', index=9, number=9,
|
||||
serialized_options=None,
|
||||
type=None),
|
||||
_descriptor.EnumValueDescriptor(
|
||||
name='MS_FLOAT16', index=10, number=10,
|
||||
serialized_options=None,
|
||||
type=None),
|
||||
_descriptor.EnumValueDescriptor(
|
||||
name='MS_FLOAT32', index=11, number=11,
|
||||
serialized_options=None,
|
||||
type=None),
|
||||
_descriptor.EnumValueDescriptor(
|
||||
name='MS_FLOAT64', index=12, number=12,
|
||||
serialized_options=None,
|
||||
type=None),
|
||||
],
|
||||
containing_type=None,
|
||||
serialized_options=None,
|
||||
serialized_start=280,
|
||||
serialized_end=481,
|
||||
)
|
||||
_sym_db.RegisterEnumDescriptor(_DATATYPE)
|
||||
|
||||
DataType = enum_type_wrapper.EnumTypeWrapper(_DATATYPE)
|
||||
MS_UNKNOWN = 0
|
||||
MS_BOOL = 1
|
||||
MS_INT8 = 2
|
||||
MS_UINT8 = 3
|
||||
MS_INT16 = 4
|
||||
MS_UINT16 = 5
|
||||
MS_INT32 = 6
|
||||
MS_UINT32 = 7
|
||||
MS_INT64 = 8
|
||||
MS_UINT64 = 9
|
||||
MS_FLOAT16 = 10
|
||||
MS_FLOAT32 = 11
|
||||
MS_FLOAT64 = 12
|
||||
|
||||
|
||||
|
||||
_PREDICTREQUEST = _descriptor.Descriptor(
|
||||
name='PredictRequest',
|
||||
full_name='ms_serving.PredictRequest',
|
||||
filename=None,
|
||||
file=DESCRIPTOR,
|
||||
containing_type=None,
|
||||
fields=[
|
||||
_descriptor.FieldDescriptor(
|
||||
name='data', full_name='ms_serving.PredictRequest.data', index=0,
|
||||
number=1, type=11, cpp_type=10, label=3,
|
||||
has_default_value=False, default_value=[],
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, file=DESCRIPTOR),
|
||||
],
|
||||
extensions=[
|
||||
],
|
||||
nested_types=[],
|
||||
enum_types=[
|
||||
],
|
||||
serialized_options=None,
|
||||
is_extendable=False,
|
||||
syntax='proto3',
|
||||
extension_ranges=[],
|
||||
oneofs=[
|
||||
],
|
||||
serialized_start=32,
|
||||
serialized_end=82,
|
||||
)
|
||||
|
||||
|
||||
_PREDICTREPLY = _descriptor.Descriptor(
|
||||
name='PredictReply',
|
||||
full_name='ms_serving.PredictReply',
|
||||
filename=None,
|
||||
file=DESCRIPTOR,
|
||||
containing_type=None,
|
||||
fields=[
|
||||
_descriptor.FieldDescriptor(
|
||||
name='result', full_name='ms_serving.PredictReply.result', index=0,
|
||||
number=1, type=11, cpp_type=10, label=3,
|
||||
has_default_value=False, default_value=[],
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, file=DESCRIPTOR),
|
||||
],
|
||||
extensions=[
|
||||
],
|
||||
nested_types=[],
|
||||
enum_types=[
|
||||
],
|
||||
serialized_options=None,
|
||||
is_extendable=False,
|
||||
syntax='proto3',
|
||||
extension_ranges=[],
|
||||
oneofs=[
|
||||
],
|
||||
serialized_start=84,
|
||||
serialized_end=134,
|
||||
)
|
||||
|
||||
|
||||
_TENSORSHAPE = _descriptor.Descriptor(
|
||||
name='TensorShape',
|
||||
full_name='ms_serving.TensorShape',
|
||||
filename=None,
|
||||
file=DESCRIPTOR,
|
||||
containing_type=None,
|
||||
fields=[
|
||||
_descriptor.FieldDescriptor(
|
||||
name='dims', full_name='ms_serving.TensorShape.dims', index=0,
|
||||
number=1, type=3, cpp_type=2, label=3,
|
||||
has_default_value=False, default_value=[],
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, file=DESCRIPTOR),
|
||||
],
|
||||
extensions=[
|
||||
],
|
||||
nested_types=[],
|
||||
enum_types=[
|
||||
],
|
||||
serialized_options=None,
|
||||
is_extendable=False,
|
||||
syntax='proto3',
|
||||
extension_ranges=[],
|
||||
oneofs=[
|
||||
],
|
||||
serialized_start=136,
|
||||
serialized_end=163,
|
||||
)
|
||||
|
||||
|
||||
_TENSOR = _descriptor.Descriptor(
|
||||
name='Tensor',
|
||||
full_name='ms_serving.Tensor',
|
||||
filename=None,
|
||||
file=DESCRIPTOR,
|
||||
containing_type=None,
|
||||
fields=[
|
||||
_descriptor.FieldDescriptor(
|
||||
name='tensor_shape', full_name='ms_serving.Tensor.tensor_shape', index=0,
|
||||
number=1, type=11, cpp_type=10, label=1,
|
||||
has_default_value=False, default_value=None,
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, file=DESCRIPTOR),
|
||||
_descriptor.FieldDescriptor(
|
||||
name='tensor_type', full_name='ms_serving.Tensor.tensor_type', index=1,
|
||||
number=2, type=14, cpp_type=8, label=1,
|
||||
has_default_value=False, default_value=0,
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, file=DESCRIPTOR),
|
||||
_descriptor.FieldDescriptor(
|
||||
name='data', full_name='ms_serving.Tensor.data', index=2,
|
||||
number=3, type=12, cpp_type=9, label=1,
|
||||
has_default_value=False, default_value=b"",
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, file=DESCRIPTOR),
|
||||
],
|
||||
extensions=[
|
||||
],
|
||||
nested_types=[],
|
||||
enum_types=[
|
||||
],
|
||||
serialized_options=None,
|
||||
is_extendable=False,
|
||||
syntax='proto3',
|
||||
extension_ranges=[],
|
||||
oneofs=[
|
||||
],
|
||||
serialized_start=165,
|
||||
serialized_end=277,
|
||||
)
|
||||
|
||||
_PREDICTREQUEST.fields_by_name['data'].message_type = _TENSOR
|
||||
_PREDICTREPLY.fields_by_name['result'].message_type = _TENSOR
|
||||
_TENSOR.fields_by_name['tensor_shape'].message_type = _TENSORSHAPE
|
||||
_TENSOR.fields_by_name['tensor_type'].enum_type = _DATATYPE
|
||||
DESCRIPTOR.message_types_by_name['PredictRequest'] = _PREDICTREQUEST
|
||||
DESCRIPTOR.message_types_by_name['PredictReply'] = _PREDICTREPLY
|
||||
DESCRIPTOR.message_types_by_name['TensorShape'] = _TENSORSHAPE
|
||||
DESCRIPTOR.message_types_by_name['Tensor'] = _TENSOR
|
||||
DESCRIPTOR.enum_types_by_name['DataType'] = _DATATYPE
|
||||
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
|
||||
|
||||
PredictRequest = _reflection.GeneratedProtocolMessageType('PredictRequest', (_message.Message,), {
|
||||
'DESCRIPTOR' : _PREDICTREQUEST,
|
||||
'__module__' : 'ms_service_pb2'
|
||||
# @@protoc_insertion_point(class_scope:ms_serving.PredictRequest)
|
||||
})
|
||||
_sym_db.RegisterMessage(PredictRequest)
|
||||
|
||||
PredictReply = _reflection.GeneratedProtocolMessageType('PredictReply', (_message.Message,), {
|
||||
'DESCRIPTOR' : _PREDICTREPLY,
|
||||
'__module__' : 'ms_service_pb2'
|
||||
# @@protoc_insertion_point(class_scope:ms_serving.PredictReply)
|
||||
})
|
||||
_sym_db.RegisterMessage(PredictReply)
|
||||
|
||||
TensorShape = _reflection.GeneratedProtocolMessageType('TensorShape', (_message.Message,), {
|
||||
'DESCRIPTOR' : _TENSORSHAPE,
|
||||
'__module__' : 'ms_service_pb2'
|
||||
# @@protoc_insertion_point(class_scope:ms_serving.TensorShape)
|
||||
})
|
||||
_sym_db.RegisterMessage(TensorShape)
|
||||
|
||||
Tensor = _reflection.GeneratedProtocolMessageType('Tensor', (_message.Message,), {
|
||||
'DESCRIPTOR' : _TENSOR,
|
||||
'__module__' : 'ms_service_pb2'
|
||||
# @@protoc_insertion_point(class_scope:ms_serving.Tensor)
|
||||
})
|
||||
_sym_db.RegisterMessage(Tensor)
|
||||
|
||||
|
||||
|
||||
_MSSERVICE = _descriptor.ServiceDescriptor(
|
||||
name='MSService',
|
||||
full_name='ms_serving.MSService',
|
||||
file=DESCRIPTOR,
|
||||
index=0,
|
||||
serialized_options=None,
|
||||
serialized_start=484,
|
||||
serialized_end=626,
|
||||
methods=[
|
||||
_descriptor.MethodDescriptor(
|
||||
name='Predict',
|
||||
full_name='ms_serving.MSService.Predict',
|
||||
index=0,
|
||||
containing_service=None,
|
||||
input_type=_PREDICTREQUEST,
|
||||
output_type=_PREDICTREPLY,
|
||||
serialized_options=None,
|
||||
),
|
||||
_descriptor.MethodDescriptor(
|
||||
name='Test',
|
||||
full_name='ms_serving.MSService.Test',
|
||||
index=1,
|
||||
containing_service=None,
|
||||
input_type=_PREDICTREQUEST,
|
||||
output_type=_PREDICTREPLY,
|
||||
serialized_options=None,
|
||||
),
|
||||
])
|
||||
_sym_db.RegisterServiceDescriptor(_MSSERVICE)
|
||||
|
||||
DESCRIPTOR.services_by_name['MSService'] = _MSSERVICE
|
||||
|
||||
# @@protoc_insertion_point(module_scope)
|
|
@ -1,96 +0,0 @@
|
|||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
import grpc
|
||||
|
||||
import ms_service_pb2 as ms__service__pb2
|
||||
|
||||
|
||||
class MSServiceStub(object):
|
||||
"""Missing associated documentation comment in .proto file"""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.Predict = channel.unary_unary(
|
||||
'/ms_serving.MSService/Predict',
|
||||
request_serializer=ms__service__pb2.PredictRequest.SerializeToString,
|
||||
response_deserializer=ms__service__pb2.PredictReply.FromString,
|
||||
)
|
||||
self.Test = channel.unary_unary(
|
||||
'/ms_serving.MSService/Test',
|
||||
request_serializer=ms__service__pb2.PredictRequest.SerializeToString,
|
||||
response_deserializer=ms__service__pb2.PredictReply.FromString,
|
||||
)
|
||||
|
||||
|
||||
class MSServiceServicer(object):
|
||||
"""Missing associated documentation comment in .proto file"""
|
||||
|
||||
def Predict(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file"""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Test(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file"""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_MSServiceServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'Predict': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Predict,
|
||||
request_deserializer=ms__service__pb2.PredictRequest.FromString,
|
||||
response_serializer=ms__service__pb2.PredictReply.SerializeToString,
|
||||
),
|
||||
'Test': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Test,
|
||||
request_deserializer=ms__service__pb2.PredictRequest.FromString,
|
||||
response_serializer=ms__service__pb2.PredictReply.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'ms_serving.MSService', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class MSService(object):
|
||||
"""Missing associated documentation comment in .proto file"""
|
||||
|
||||
@staticmethod
|
||||
def Predict(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/ms_serving.MSService/Predict',
|
||||
ms__service__pb2.PredictRequest.SerializeToString,
|
||||
ms__service__pb2.PredictReply.FromString,
|
||||
options, channel_credentials,
|
||||
call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
|
||||
@staticmethod
|
||||
def Test(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(request, target, '/ms_serving.MSService/Test',
|
||||
ms__service__pb2.PredictRequest.SerializeToString,
|
||||
ms__service__pb2.PredictReply.FromString,
|
||||
options, channel_credentials,
|
||||
call_credentials, compression, wait_for_ready, timeout, metadata)
|
|
@ -1,67 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <grpcpp/grpcpp.h>
|
||||
#include <grpcpp/health_check_service_interface.h>
|
||||
#include <grpcpp/ext/proto_server_reflection_plugin.h>
|
||||
#include <iostream>
|
||||
|
||||
#include "./ms_service.grpc.pb.h"
|
||||
|
||||
using grpc::Server;
|
||||
using grpc::ServerBuilder;
|
||||
using grpc::ServerContext;
|
||||
using grpc::Status;
|
||||
using ms_serving::MSService;
|
||||
using ms_serving::PredictReply;
|
||||
using ms_serving::PredictRequest;
|
||||
|
||||
// Logic and data behind the server's behavior.
|
||||
class MSServiceImpl final : public MSService::Service {
|
||||
Status Predict(ServerContext *context, const PredictRequest *request, PredictReply *reply) override {
|
||||
std::cout << "server eval" << std::endl;
|
||||
return Status::OK;
|
||||
}
|
||||
};
|
||||
|
||||
void RunServer() {
|
||||
std::string server_address("0.0.0.0:50051");
|
||||
MSServiceImpl service;
|
||||
|
||||
grpc::EnableDefaultHealthCheckService(true);
|
||||
grpc::reflection::InitProtoReflectionServerBuilderPlugin();
|
||||
auto option = grpc::MakeChannelArgumentOption(GRPC_ARG_ALLOW_REUSEPORT, 0);
|
||||
|
||||
ServerBuilder builder;
|
||||
builder.SetOption(std::move(option));
|
||||
// Listen on the given address without any authentication mechanism.
|
||||
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
||||
// Register "service" as the instance through which we'll communicate with
|
||||
// clients. In this case it corresponds to an *synchronous* service.
|
||||
builder.RegisterService(&service);
|
||||
// Finally assemble the server.
|
||||
std::unique_ptr<Server> server(builder.BuildAndStart());
|
||||
std::cout << "Server listening on " << server_address << std::endl;
|
||||
|
||||
// Wait for the server to shutdown. Note that some other thread must be
|
||||
// responsible for shutting down the server for this call to ever return.
|
||||
server->Wait();
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
RunServer();
|
||||
|
||||
return 0;
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
cmake_minimum_required(VERSION 3.5.1)
|
||||
|
||||
project(HelloWorld C CXX)
|
||||
project(MSClient C CXX)
|
||||
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
|
||||
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
|
||||
|
@ -12,17 +12,33 @@ find_package(Threads REQUIRED)
|
|||
|
||||
# Find Protobuf installation
|
||||
# Looks for protobuf-config.cmake file installed by Protobuf's cmake installation.
|
||||
set(protobuf_MODULE_COMPATIBLE TRUE)
|
||||
find_package(Protobuf CONFIG REQUIRED)
|
||||
message(STATUS "Using protobuf ${protobuf_VERSION}")
|
||||
option(GRPC_PATH "set grpc path")
|
||||
if(GRPC_PATH)
|
||||
set(CMAKE_PREFIX_PATH ${GRPC_PATH})
|
||||
set(protobuf_MODULE_COMPATIBLE TRUE)
|
||||
find_package(Protobuf CONFIG REQUIRED)
|
||||
message(STATUS "Using protobuf ${protobuf_VERSION}, CMAKE_PREFIX_PATH : ${CMAKE_PREFIX_PATH}")
|
||||
elseif(NOT GRPC_PATH)
|
||||
if (EXISTS ${grpc_ROOT}/lib64)
|
||||
set(gRPC_DIR "${grpc_ROOT}/lib64/cmake/grpc")
|
||||
elseif(EXISTS ${grpc_ROOT}/lib)
|
||||
set(gRPC_DIR "${grpc_ROOT}/lib/cmake/grpc")
|
||||
endif()
|
||||
add_library(protobuf::libprotobuf ALIAS protobuf::protobuf)
|
||||
add_executable(protobuf::libprotoc ALIAS protobuf::protoc)
|
||||
message(STATUS "serving using grpc_DIR : " ${gRPC_DIR})
|
||||
elseif(NOT gRPC_DIR AND NOT GRPC_PATH)
|
||||
message("please check gRPC. If the client is compiled separately,you can use the command: cmake -D GRPC_PATH=xxx")
|
||||
message("XXX is the gRPC installation path")
|
||||
endif()
|
||||
|
||||
set(_PROTOBUF_LIBPROTOBUF protobuf::libprotobuf)
|
||||
set(_REFLECTION gRPC::grpc++_reflection)
|
||||
if (CMAKE_CROSSCOMPILING)
|
||||
find_program(_PROTOBUF_PROTOC protoc)
|
||||
else ()
|
||||
set(_PROTOBUF_PROTOC $<TARGET_FILE:protobuf::protoc>)
|
||||
endif ()
|
||||
if(CMAKE_CROSSCOMPILING)
|
||||
find_program(_PROTOBUF_PROTOC protoc)
|
||||
else()
|
||||
set(_PROTOBUF_PROTOC $<TARGET_FILE:protobuf::protoc>)
|
||||
endif()
|
||||
|
||||
# Find gRPC installation
|
||||
# Looks for gRPCConfig.cmake file installed by gRPC's cmake installation.
|
||||
|
@ -30,14 +46,14 @@ find_package(gRPC CONFIG REQUIRED)
|
|||
message(STATUS "Using gRPC ${gRPC_VERSION}")
|
||||
|
||||
set(_GRPC_GRPCPP gRPC::grpc++)
|
||||
if (CMAKE_CROSSCOMPILING)
|
||||
find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin)
|
||||
else ()
|
||||
set(_GRPC_CPP_PLUGIN_EXECUTABLE $<TARGET_FILE:gRPC::grpc_cpp_plugin>)
|
||||
endif ()
|
||||
if(CMAKE_CROSSCOMPILING)
|
||||
find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin)
|
||||
else()
|
||||
set(_GRPC_CPP_PLUGIN_EXECUTABLE $<TARGET_FILE:gRPC::grpc_cpp_plugin>)
|
||||
endif()
|
||||
|
||||
# Proto file
|
||||
get_filename_component(hw_proto "../ms_service.proto" ABSOLUTE)
|
||||
get_filename_component(hw_proto "../../ms_service.proto" ABSOLUTE)
|
||||
get_filename_component(hw_proto_path "${hw_proto}" PATH)
|
||||
|
||||
# Generated sources
|
||||
|
@ -59,13 +75,13 @@ add_custom_command(
|
|||
include_directories("${CMAKE_CURRENT_BINARY_DIR}")
|
||||
|
||||
# Targets greeter_[async_](client|server)
|
||||
foreach (_target
|
||||
ms_client ms_server)
|
||||
add_executable(${_target} "${_target}.cc"
|
||||
${hw_proto_srcs}
|
||||
${hw_grpc_srcs})
|
||||
target_link_libraries(${_target}
|
||||
${_REFLECTION}
|
||||
${_GRPC_GRPCPP}
|
||||
${_PROTOBUF_LIBPROTOBUF})
|
||||
endforeach ()
|
||||
foreach(_target
|
||||
ms_client)
|
||||
add_executable(${_target} "${_target}.cc"
|
||||
${hw_proto_srcs}
|
||||
${hw_grpc_srcs})
|
||||
target_link_libraries(${_target}
|
||||
${_REFLECTION}
|
||||
${_GRPC_GRPCPP}
|
||||
${_PROTOBUF_LIBPROTOBUF})
|
||||
endforeach()
|
|
@ -211,77 +211,12 @@ PredictRequest ReadBertInput() {
|
|||
return request;
|
||||
}
|
||||
|
||||
PredictRequest ReadLenetInput() {
|
||||
size_t size;
|
||||
auto buf = ReadFile("lenet_img.bin", &size);
|
||||
if (buf == nullptr) {
|
||||
std::cout << "read file failed" << std::endl;
|
||||
return PredictRequest();
|
||||
}
|
||||
PredictRequest request;
|
||||
auto cur = buf;
|
||||
if (size > 0) {
|
||||
Tensor data;
|
||||
TensorShape shape;
|
||||
// set type
|
||||
data.set_tensor_type(ms_serving::MS_FLOAT32);
|
||||
|
||||
// set shape
|
||||
shape.add_dims(size / sizeof(float));
|
||||
*data.mutable_tensor_shape() = shape;
|
||||
|
||||
// set data
|
||||
data.set_data(cur, size);
|
||||
*request.add_data() = data;
|
||||
}
|
||||
std::cout << "get input data size " << size << std::endl;
|
||||
return request;
|
||||
}
|
||||
|
||||
PredictRequest ReadOtherInput(const std::string &data_file) {
|
||||
size_t size;
|
||||
auto buf = ReadFile(data_file.c_str(), &size);
|
||||
if (buf == nullptr) {
|
||||
std::cout << "read file failed" << std::endl;
|
||||
return PredictRequest();
|
||||
}
|
||||
PredictRequest request;
|
||||
auto cur = buf;
|
||||
if (size > 0) {
|
||||
Tensor data;
|
||||
TensorShape shape;
|
||||
// set type
|
||||
data.set_tensor_type(ms_serving::MS_FLOAT32);
|
||||
|
||||
// set shape
|
||||
shape.add_dims(size / sizeof(float));
|
||||
*data.mutable_tensor_shape() = shape;
|
||||
|
||||
// set data
|
||||
data.set_data(cur, size);
|
||||
*request.add_data() = data;
|
||||
}
|
||||
std::cout << "get input data size " << size << std::endl;
|
||||
return request;
|
||||
}
|
||||
|
||||
template <class DT>
|
||||
void print_array_item(const DT *data, size_t size) {
|
||||
for (size_t i = 0; i < size && i < 100; i++) {
|
||||
std::cout << data[i] << '\t';
|
||||
if ((i + 1) % 10 == 0) {
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
class MSClient {
|
||||
public:
|
||||
explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {}
|
||||
~MSClient() = default;
|
||||
|
||||
std::string Predict(const std::string &type, const std::string &data_file) {
|
||||
std::string Predict(const std::string &type) {
|
||||
// Data we are sending to the server.
|
||||
PredictRequest request;
|
||||
if (type == "add") {
|
||||
|
@ -299,10 +234,6 @@ class MSClient {
|
|||
*request.add_data() = data;
|
||||
} else if (type == "bert") {
|
||||
request = ReadBertInput();
|
||||
} else if (type == "lenet") {
|
||||
request = ReadLenetInput();
|
||||
} else if (type == "other") {
|
||||
request = ReadOtherInput(data_file);
|
||||
} else {
|
||||
std::cout << "type only support bert or add, but input is " << type << std::endl;
|
||||
}
|
||||
|
@ -325,20 +256,6 @@ class MSClient {
|
|||
|
||||
// Act upon its status.
|
||||
if (status.ok()) {
|
||||
for (size_t i = 0; i < reply.result_size(); i++) {
|
||||
auto result = reply.result(i);
|
||||
if (result.tensor_type() == ms_serving::DataType::MS_FLOAT32) {
|
||||
print_array_item(reinterpret_cast<const float *>(result.data().data()), result.data().size() / sizeof(float));
|
||||
} else if (result.tensor_type() == ms_serving::DataType::MS_INT32) {
|
||||
print_array_item(reinterpret_cast<const int32_t *>(result.data().data()),
|
||||
result.data().size() / sizeof(int32_t));
|
||||
} else if (result.tensor_type() == ms_serving::DataType::MS_UINT32) {
|
||||
print_array_item(reinterpret_cast<const uint32_t *>(result.data().data()),
|
||||
result.data().size() / sizeof(uint32_t));
|
||||
} else {
|
||||
std::cout << "output datatype " << result.tensor_type() << std::endl;
|
||||
}
|
||||
}
|
||||
return "RPC OK";
|
||||
} else {
|
||||
std::cout << status.error_code() << ": " << status.error_message() << std::endl;
|
||||
|
@ -360,8 +277,6 @@ int main(int argc, char **argv) {
|
|||
std::string arg_target_str("--target");
|
||||
std::string type;
|
||||
std::string arg_type_str("--type");
|
||||
std::string arg_data_str("--data");
|
||||
std::string data = "default_data.bin";
|
||||
if (argc > 2) {
|
||||
{
|
||||
// parse target
|
||||
|
@ -389,33 +304,19 @@ int main(int argc, char **argv) {
|
|||
if (arg_val2[start_pos] == '=') {
|
||||
type = arg_val2.substr(start_pos + 1);
|
||||
} else {
|
||||
std::cout << "The only correct argument syntax is --type=" << std::endl;
|
||||
std::cout << "The only correct argument syntax is --target=" << std::endl;
|
||||
return 0;
|
||||
}
|
||||
} else {
|
||||
type = "add";
|
||||
}
|
||||
}
|
||||
if (argc > 3) {
|
||||
// parse type
|
||||
std::string arg_val3 = argv[3];
|
||||
size_t start_pos = arg_val3.find(arg_data_str);
|
||||
if (start_pos != std::string::npos) {
|
||||
start_pos += arg_data_str.size();
|
||||
if (arg_val3[start_pos] == '=') {
|
||||
data = arg_val3.substr(start_pos + 1);
|
||||
} else {
|
||||
std::cout << "The only correct argument syntax is --data=" << std::endl;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
target_str = "localhost:5500";
|
||||
type = "add";
|
||||
}
|
||||
MSClient client(grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials()));
|
||||
std::string reply = client.Predict(type, data);
|
||||
std::string reply = client.Predict(type);
|
||||
std::cout << "client received: " << reply << std::endl;
|
||||
|
||||
return 0;
|
|
@ -0,0 +1,46 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.serialization import export
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.add = P.TensorAdd()
|
||||
|
||||
def construct(self, x_, y_):
|
||||
return self.add(x_, y_)
|
||||
|
||||
x = np.ones(4).astype(np.float32)
|
||||
y = np.ones(4).astype(np.float32)
|
||||
|
||||
def export_net():
|
||||
add = Net()
|
||||
output = add(Tensor(x), Tensor(y))
|
||||
export(add, Tensor(x), Tensor(y), file_name='tensor_add.pb', file_format='BINARY')
|
||||
print(x)
|
||||
print(y)
|
||||
print(output.asnumpy())
|
||||
|
||||
if __name__ == "__main__":
|
||||
export_net()
|
||||
|
|
@ -19,28 +19,25 @@ import ms_service_pb2_grpc
|
|||
|
||||
|
||||
def run():
|
||||
channel = grpc.insecure_channel('localhost:50051')
|
||||
channel = grpc.insecure_channel('localhost:5050')
|
||||
stub = ms_service_pb2_grpc.MSServiceStub(channel)
|
||||
# request = ms_service_pb2.EvalRequest()
|
||||
# request.name = 'haha'
|
||||
# response = stub.Eval(request)
|
||||
# print("ms client received: " + response.message)
|
||||
|
||||
request = ms_service_pb2.PredictRequest()
|
||||
request.data.tensor_shape.dims.extend([32, 1, 32, 32])
|
||||
request.data.tensor_type = ms_service_pb2.MS_FLOAT32
|
||||
request.data.data = (np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01).tobytes()
|
||||
|
||||
request.label.tensor_shape.dims.extend([32])
|
||||
request.label.tensor_type = ms_service_pb2.MS_INT32
|
||||
request.label.data = np.ones([32]).astype(np.int32).tobytes()
|
||||
|
||||
result = stub.Test(request)
|
||||
#result_np = np.frombuffer(result.result.data, dtype=np.float32).reshape(result.result.tensor_shape.dims)
|
||||
print("ms client test call received: ")
|
||||
#print(result_np)
|
||||
x = request.data.add()
|
||||
x.tensor_shape.dims.extend([4])
|
||||
x.tensor_type = ms_service_pb2.MS_FLOAT32
|
||||
x.data = (np.ones([4]).astype(np.float32)).tobytes()
|
||||
|
||||
y = request.data.add()
|
||||
y.tensor_shape.dims.extend([4])
|
||||
y.tensor_type = ms_service_pb2.MS_FLOAT32
|
||||
y.data = (np.ones([4]).astype(np.float32)).tobytes()
|
||||
|
||||
result = stub.Predict(request)
|
||||
print(result)
|
||||
result_np = np.frombuffer(result.result[0].data, dtype=np.float32).reshape(result.result[0].tensor_shape.dims)
|
||||
print("ms client received: ")
|
||||
print(result_np)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
|
@ -1,57 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import grpc
|
||||
import numpy as np
|
||||
import ms_service_pb2
|
||||
import ms_service_pb2_grpc
|
||||
|
||||
|
||||
def run():
|
||||
channel = grpc.insecure_channel('localhost:50051')
|
||||
stub = ms_service_pb2_grpc.MSServiceStub(channel)
|
||||
# request = ms_service_pb2.PredictRequest()
|
||||
# request.name = 'haha'
|
||||
# response = stub.Eval(request)
|
||||
# print("ms client received: " + response.message)
|
||||
|
||||
request = ms_service_pb2.PredictRequest()
|
||||
request.data.tensor_shape.dims.extend([32, 1, 32, 32])
|
||||
request.data.tensor_type = ms_service_pb2.MS_FLOAT32
|
||||
request.data.data = (np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01).tobytes()
|
||||
|
||||
request.label.tensor_shape.dims.extend([32])
|
||||
request.label.tensor_type = ms_service_pb2.MS_INT32
|
||||
request.label.data = np.ones([32]).astype(np.int32).tobytes()
|
||||
|
||||
result = stub.Predict(request)
|
||||
#result_np = np.frombuffer(result.result.data, dtype=np.float32).reshape(result.result.tensor_shape.dims)
|
||||
print("ms client received: ")
|
||||
#print(result_np)
|
||||
|
||||
# future_list = []
|
||||
# times = 1000
|
||||
# for i in range(times):
|
||||
# async_future = stub.Eval.future(request)
|
||||
# future_list.append(async_future)
|
||||
# print("async call, future list add item " + str(i));
|
||||
#
|
||||
# for i in range(len(future_list)):
|
||||
# async_result = future_list[i].result()
|
||||
# print("ms client async get result of item " + str(i))
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
|
@ -1,55 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from concurrent import futures
|
||||
import time
|
||||
import grpc
|
||||
import numpy as np
|
||||
import ms_service_pb2
|
||||
import ms_service_pb2_grpc
|
||||
import test_cpu_lenet
|
||||
from mindspore import Tensor
|
||||
|
||||
class MSService(ms_service_pb2_grpc.MSServiceServicer):
|
||||
def Predict(self, request, context):
|
||||
request_data = request.data
|
||||
request_label = request.label
|
||||
|
||||
data_from_buffer = np.frombuffer(request_data.data, dtype=np.float32)
|
||||
data_from_buffer = data_from_buffer.reshape(request_data.tensor_shape.dims)
|
||||
data = Tensor(data_from_buffer)
|
||||
|
||||
label_from_buffer = np.frombuffer(request_label.data, dtype=np.int32)
|
||||
label_from_buffer = label_from_buffer.reshape(request_label.tensor_shape.dims)
|
||||
label = Tensor(label_from_buffer)
|
||||
|
||||
result = test_cpu_lenet.test_lenet(data, label)
|
||||
result_reply = ms_service_pb2.PredictReply()
|
||||
result_reply.result.tensor_shape.dims.extend(result.shape())
|
||||
result_reply.result.data = result.asnumpy().tobytes()
|
||||
return result_reply
|
||||
|
||||
def serve():
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
|
||||
ms_service_pb2_grpc.add_MSServiceServicer_to_server(MSService(), server)
|
||||
server.add_insecure_port('[::]:50051')
|
||||
server.start()
|
||||
try:
|
||||
while True:
|
||||
time.sleep(60*60*24) # one day in seconds
|
||||
except KeyboardInterrupt:
|
||||
server.stop(0)
|
||||
|
||||
if __name__ == '__main__':
|
||||
serve()
|
|
@ -1,91 +0,0 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn import TrainOneStepCell, WithLossCell
|
||||
from mindspore.nn.optim import Momentum
|
||||
from mindspore.ops import operations as P
|
||||
import ms_service_pb2
|
||||
|
||||
|
||||
class LeNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(LeNet, self).__init__()
|
||||
self.relu = P.ReLU()
|
||||
self.batch_size = 32
|
||||
|
||||
self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
|
||||
self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
|
||||
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.reshape = P.Reshape()
|
||||
self.fc1 = nn.Dense(400, 120)
|
||||
self.fc2 = nn.Dense(120, 84)
|
||||
self.fc3 = nn.Dense(84, 10)
|
||||
|
||||
def construct(self, input_x):
|
||||
output = self.conv1(input_x)
|
||||
output = self.relu(output)
|
||||
output = self.pool(output)
|
||||
output = self.conv2(output)
|
||||
output = self.relu(output)
|
||||
output = self.pool(output)
|
||||
output = self.reshape(output, (self.batch_size, -1))
|
||||
output = self.fc1(output)
|
||||
output = self.relu(output)
|
||||
output = self.fc2(output)
|
||||
output = self.relu(output)
|
||||
output = self.fc3(output)
|
||||
return output
|
||||
|
||||
|
||||
def train(net, data, label):
|
||||
learning_rate = 0.01
|
||||
momentum = 0.9
|
||||
|
||||
optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
|
||||
criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
|
||||
net_with_criterion = WithLossCell(net, criterion)
|
||||
train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer
|
||||
train_network.set_train()
|
||||
res = train_network(data, label)
|
||||
print("+++++++++Loss+++++++++++++")
|
||||
print(res)
|
||||
print("+++++++++++++++++++++++++++")
|
||||
assert res
|
||||
return res
|
||||
|
||||
def test_lenet(data, label):
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
net = LeNet()
|
||||
return train(net, data, label)
|
||||
|
||||
if __name__ == '__main__':
|
||||
tensor = ms_service_pb2.Tensor()
|
||||
tensor.tensor_shape.dim.extend([32, 1, 32, 32])
|
||||
# tensor.tensor_shape.dim.add() = 1
|
||||
# tensor.tensor_shape.dim.add() = 32
|
||||
# tensor.tensor_shape.dim.add() = 32
|
||||
tensor.tensor_type = ms_service_pb2.MS_FLOAT32
|
||||
tensor.data = np.ones([32, 1, 32, 32]).astype(np.float32).tobytes()
|
||||
|
||||
data_from_buffer = np.frombuffer(tensor.data, dtype=np.float32)
|
||||
print(tensor.tensor_shape.dim)
|
||||
data_from_buffer = data_from_buffer.reshape(tensor.tensor_shape.dim)
|
||||
print(data_from_buffer.shape)
|
||||
input_data = Tensor(data_from_buffer * 0.01)
|
||||
input_label = Tensor(np.ones([32]).astype(np.int32))
|
||||
test_lenet(input_data, input_label)
|
Loading…
Reference in New Issue