forked from mindspore-Ecosystem/mindspore
114 lines
3.7 KiB
C++
114 lines
3.7 KiB
C++
/**
|
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include <grpcpp/grpcpp.h>

#include "./ms_service.grpc.pb.h"

using grpc::Channel;
using grpc::ClientContext;
using grpc::Status;
using ms_serving::MSService;
using ms_serving::PredictReply;
using ms_serving::PredictRequest;
using ms_serving::Tensor;
using ms_serving::TensorShape;
|
|
|
|
class MSClient {
|
|
public:
|
|
explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {}
|
|
|
|
~MSClient() = default;
|
|
|
|
std::string Predict() {
|
|
// Data we are sending to the server.
|
|
PredictRequest request;
|
|
|
|
Tensor data;
|
|
TensorShape shape;
|
|
shape.add_dims(2);
|
|
shape.add_dims(2);
|
|
*data.mutable_tensor_shape() = shape;
|
|
data.set_tensor_type(ms_serving::MS_FLOAT32);
|
|
std::vector<float> input_data{1, 2, 3, 4};
|
|
data.set_data(input_data.data(), input_data.size() * sizeof(float));
|
|
*request.add_data() = data;
|
|
*request.add_data() = data;
|
|
std::cout << "intput tensor size is " << request.data_size() << std::endl;
|
|
// Container for the data we expect from the server.
|
|
PredictReply reply;
|
|
|
|
// Context for the client. It could be used to convey extra information to
|
|
// the server and/or tweak certain RPC behaviors.
|
|
ClientContext context;
|
|
|
|
// The actual RPC.
|
|
Status status = stub_->Predict(&context, request, &reply);
|
|
std::cout << "Compute [[1, 2], [3, 4]] + [[1, 2], [3, 4]]" << std::endl;
|
|
|
|
// Act upon its status.
|
|
if (status.ok()) {
|
|
std::cout << "Add result is";
|
|
for (size_t i = 0; i < reply.result(0).data().size() / sizeof(float); i++) {
|
|
std::cout << " " << (reinterpret_cast<const float *>(reply.mutable_result(0)->mutable_data()->data()))[i];
|
|
}
|
|
std::cout << std::endl;
|
|
return "RPC OK";
|
|
} else {
|
|
std::cout << status.error_code() << ": " << status.error_message() << std::endl;
|
|
return "RPC failed";
|
|
}
|
|
}
|
|
|
|
private:
|
|
std::unique_ptr<MSService::Stub> stub_;
|
|
};
|
|
|
|
int main(int argc, char **argv) {
|
|
// Instantiate the client. It requires a channel, out of which the actual RPCs
|
|
// are created. This channel models a connection to an endpoint specified by
|
|
// the argument "--target=" which is the only expected argument.
|
|
// We indicate that the channel isn't authenticated (use of
|
|
// InsecureChannelCredentials()).
|
|
std::string target_str;
|
|
std::string arg_target_str("--target");
|
|
if (argc > 1) {
|
|
// parse target
|
|
std::string arg_val = argv[1];
|
|
size_t start_pos = arg_val.find(arg_target_str);
|
|
if (start_pos != std::string::npos) {
|
|
start_pos += arg_target_str.size();
|
|
if (start_pos < arg_val.size() && arg_val[start_pos] == '=') {
|
|
target_str = arg_val.substr(start_pos + 1);
|
|
} else {
|
|
std::cout << "The only correct argument syntax is --target=" << std::endl;
|
|
return 0;
|
|
}
|
|
} else {
|
|
target_str = "localhost:5500";
|
|
}
|
|
} else {
|
|
target_str = "localhost:5500";
|
|
}
|
|
MSClient client(grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials()));
|
|
std::string reply = client.Predict();
|
|
std::cout << "client received: " << reply << std::endl;
|
|
|
|
return 0;
|
|
}
|