mindspore/serving/example/cpp_client/ms_client.cc

114 lines
3.7 KiB
C++

/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <grpcpp/grpcpp.h>
#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "./ms_service.grpc.pb.h"
using grpc::Channel;
using grpc::ClientContext;
using grpc::Status;
using ms_serving::MSService;
using ms_serving::PredictReply;
using ms_serving::PredictRequest;
using ms_serving::Tensor;
using ms_serving::TensorShape;
class MSClient {
public:
explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {}
~MSClient() = default;
std::string Predict() {
// Data we are sending to the server.
PredictRequest request;
Tensor data;
TensorShape shape;
shape.add_dims(2);
shape.add_dims(2);
*data.mutable_tensor_shape() = shape;
data.set_tensor_type(ms_serving::MS_FLOAT32);
std::vector<float> input_data{1, 2, 3, 4};
data.set_data(input_data.data(), input_data.size() * sizeof(float));
*request.add_data() = data;
*request.add_data() = data;
std::cout << "intput tensor size is " << request.data_size() << std::endl;
// Container for the data we expect from the server.
PredictReply reply;
// Context for the client. It could be used to convey extra information to
// the server and/or tweak certain RPC behaviors.
ClientContext context;
// The actual RPC.
Status status = stub_->Predict(&context, request, &reply);
std::cout << "Compute [[1, 2], [3, 4]] + [[1, 2], [3, 4]]" << std::endl;
// Act upon its status.
if (status.ok()) {
std::cout << "Add result is";
for (size_t i = 0; i < reply.result(0).data().size() / sizeof(float); i++) {
std::cout << " " << (reinterpret_cast<const float *>(reply.mutable_result(0)->mutable_data()->data()))[i];
}
std::cout << std::endl;
return "RPC OK";
} else {
std::cout << status.error_code() << ": " << status.error_message() << std::endl;
return "RPC failed";
}
}
private:
std::unique_ptr<MSService::Stub> stub_;
};
/**
 * Entry point of the example client.
 *
 * Accepts one optional argument of the form "--target=<endpoint>". When no
 * argument is given, or the first argument does not start with "--target",
 * the endpoint defaults to "localhost:5500". Only argv[1] is inspected.
 *
 * @return 0 after the Predict call has been issued (regardless of its result),
 *         1 when the first argument starts with "--target" but is not followed
 *         by '='.
 */
int main(int argc, char **argv) {
  // Instantiate the client. It requires a channel, out of which the actual RPCs
  // are created. This channel models a connection to an endpoint specified by
  // the argument "--target=" which is the only expected argument.
  // We indicate that the channel isn't authenticated (use of
  // InsecureChannelCredentials()).
  const std::string arg_target_str("--target");
  std::string target_str("localhost:5500");

  if (argc > 1) {
    // parse target
    const std::string arg_val = argv[1];
    // Anchor the match at the start of the argument: a plain find() would also
    // accept the flag embedded in unrelated text such as "junk--target=x".
    if (arg_val.compare(0, arg_target_str.size(), arg_target_str) == 0) {
      const size_t separator_pos = arg_target_str.size();
      if (separator_pos < arg_val.size() && arg_val[separator_pos] == '=') {
        target_str = arg_val.substr(separator_pos + 1);
      } else {
        std::cout << "The only correct argument syntax is --target=" << std::endl;
        // A malformed argument is a usage error; report it via the exit code.
        return 1;
      }
    }
  }

  MSClient client(grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials()));
  const std::string reply = client.Predict();
  std::cout << "client received: " << reply << std::endl;
  return 0;
}