diff --git a/cmake/package.cmake b/cmake/package.cmake index 5504dc6848d..3163af4bfb0 100644 --- a/cmake/package.cmake +++ b/cmake/package.cmake @@ -73,6 +73,16 @@ if (USE_GLOG) ) endif () +file(GLOB_RECURSE LIBEVENT_LIB_LIST + ${libevent_LIBPATH}/libevent* + ${libevent_LIBPATH}/libevent_pthreads* + ) + +install( + FILES ${LIBEVENT_LIB_LIST} + DESTINATION ${INSTALL_LIB_DIR} + COMPONENT mindspore +) if (ENABLE_MINDDATA) install( TARGETS _c_dataengine _c_mindrecord diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index 128506c2093..22edb74b6b7 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -278,7 +278,7 @@ if (CMAKE_SYSTEM_NAME MATCHES "Windows") target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore -Wl,--no-whole-archive) else () if (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) - target_link_libraries(mindspore mindspore::pslite mindspore::protobuf ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a) + target_link_libraries(mindspore mindspore::pslite mindspore::protobuf mindspore::event mindspore::event_pthreads ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a) if (${ENABLE_IBVERBS} STREQUAL "ON") target_link_libraries(mindspore ibverbs rdmacm) endif() diff --git a/mindspore/ccsrc/ps/CMakeLists.txt b/mindspore/ccsrc/ps/CMakeLists.txt index e5df5353cb6..9bc3f313120 100644 --- a/mindspore/ccsrc/ps/CMakeLists.txt +++ b/mindspore/ccsrc/ps/CMakeLists.txt @@ -5,6 +5,8 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) list(REMOVE_ITEM _PS_SRC_FILES "optimizer_info.cc") list(REMOVE_ITEM _PS_SRC_FILES "scheduler.cc") list(REMOVE_ITEM _PS_SRC_FILES "util.cc") + list(REMOVE_ITEM _PS_SRC_FILES "comm/http_message_handler.cc") + list(REMOVE_ITEM _PS_SRC_FILES "comm/http_server.cc") endif() set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS) diff --git a/mindspore/ccsrc/ps/comm/http_message_handler.cc b/mindspore/ccsrc/ps/comm/http_message_handler.cc new file mode 100644 index 00000000000..c09dba168c8 --- /dev/null +++ b/mindspore/ccsrc/ps/comm/http_message_handler.cc @@ -0,0 +1,222 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/comm/http_message_handler.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace mindspore { +namespace ps { +namespace comm { + +HttpMessageHandler::~HttpMessageHandler() { + if (!event_request_) { + evhttp_request_free(event_request_); + event_request_ = nullptr; + } + if (!event_uri_) { + evhttp_uri_free(const_cast(event_uri_)); + event_uri_ = nullptr; + } + if (!resp_buf_) { + evbuffer_free(resp_buf_); + resp_buf_ = nullptr; + } +} + +void HttpMessageHandler::InitHttpMessage() { + MS_EXCEPTION_IF_NULL(event_request_); + event_uri_ = evhttp_request_get_evhttp_uri(event_request_); + MS_EXCEPTION_IF_NULL(event_uri_); + + const char *query = evhttp_uri_get_query(event_uri_); + if (query) { + evhttp_parse_query_str(query, &path_params_); + } + + head_params_ = evhttp_request_get_input_headers(event_request_); + resp_headers_ = evhttp_request_get_output_headers(event_request_); + resp_buf_ = evhttp_request_get_output_buffer(event_request_); +} + +std::string HttpMessageHandler::GetHeadParam(const std::string &key) { + MS_EXCEPTION_IF_NULL(head_params_); + const char *val = evhttp_find_header(head_params_, key.c_str()); + MS_EXCEPTION_IF_NULL(val); + return std::string(val); +} + +std::string HttpMessageHandler::GetPathParam(const std::string &key) { + const char *val = evhttp_find_header(&path_params_, key.c_str()); + MS_EXCEPTION_IF_NULL(val); + return std::string(val); +} + +void HttpMessageHandler::ParsePostParam() { + MS_EXCEPTION_IF_NULL(event_request_); + size_t len = evbuffer_get_length(event_request_->input_buffer); + if (len == 0) { + MS_LOG(EXCEPTION) << "The post parameter size is: " << len; + } + post_param_parsed_ = true; + const char *post_message = reinterpret_cast(evbuffer_pullup(event_request_->input_buffer, -1)); + MS_EXCEPTION_IF_NULL(post_message); + body_ = std::make_unique(post_message, len); + int ret = evhttp_parse_query_str(body_->c_str(), &post_params_); + if (ret == -1) { + MS_LOG(EXCEPTION) << "Parse post parameter failed!"; + } +} + +std::string HttpMessageHandler::GetPostParam(const std::string &key) { + if (!post_param_parsed_) { + ParsePostParam(); + } + + const char *val = evhttp_find_header(&post_params_, key.c_str()); + MS_EXCEPTION_IF_NULL(val); + return std::string(val); +} + +std::string HttpMessageHandler::GetRequestUri() { + MS_EXCEPTION_IF_NULL(event_request_); + const char *uri = evhttp_request_get_uri(event_request_); + MS_EXCEPTION_IF_NULL(uri); + return std::string(uri); +} + +std::string HttpMessageHandler::GetRequestHost() { + MS_EXCEPTION_IF_NULL(event_request_); + const char *host = evhttp_request_get_host(event_request_); + MS_EXCEPTION_IF_NULL(host); + return std::string(host); +} + +int HttpMessageHandler::GetUriPort() { + MS_EXCEPTION_IF_NULL(event_uri_); + return evhttp_uri_get_port(event_uri_); +} + +std::string HttpMessageHandler::GetUriPath() { + MS_EXCEPTION_IF_NULL(event_uri_); + const char *path = evhttp_uri_get_path(event_uri_); + MS_EXCEPTION_IF_NULL(path); + return std::string(path); +} + +std::string HttpMessageHandler::GetUriQuery() { + MS_EXCEPTION_IF_NULL(event_uri_); + const char *query = evhttp_uri_get_query(event_uri_); + MS_EXCEPTION_IF_NULL(query); + return std::string(query); +} + +std::string HttpMessageHandler::GetUriFragment() { + MS_EXCEPTION_IF_NULL(event_uri_); + const char *fragment = evhttp_uri_get_fragment(event_uri_); + MS_EXCEPTION_IF_NULL(fragment); + return std::string(fragment); +} + +std::string HttpMessageHandler::GetPostMsg() { + MS_EXCEPTION_IF_NULL(event_request_); + if (body_ != nullptr) { + return *body_; + } + size_t len = evbuffer_get_length(event_request_->input_buffer); + if (len == 0) { + MS_LOG(EXCEPTION) << "The post message is empty!"; + } + const char *post_message = reinterpret_cast(evbuffer_pullup(event_request_->input_buffer, -1)); + MS_EXCEPTION_IF_NULL(post_message); + body_ = std::make_unique(post_message, len); + return *body_; +} + +void HttpMessageHandler::AddRespHeadParam(const std::string &key, const std::string &val) { + MS_EXCEPTION_IF_NULL(resp_headers_); + if (evhttp_add_header(resp_headers_, key.c_str(), val.c_str()) != 0) { + MS_LOG(EXCEPTION) << "Add parameter of response header failed."; + } +} + +void HttpMessageHandler::AddRespHeaders(const HttpHeaders &headers) { + for (auto iter = headers.begin(); iter != headers.end(); ++iter) { + auto list = iter->second; + for (auto iterator_val = list.begin(); iterator_val != list.end(); ++iterator_val) { + AddRespHeadParam(iter->first, *iterator_val); + } + } +} + +void HttpMessageHandler::AddRespString(const std::string &str) { + MS_EXCEPTION_IF_NULL(resp_buf_); + if (evbuffer_add_printf(resp_buf_, "%s", str.c_str()) == -1) { + MS_LOG(EXCEPTION) << "Add string to response body failed."; + } +} + +void HttpMessageHandler::SetRespCode(int code) { resp_code_ = code; } + +void HttpMessageHandler::SendResponse() { + MS_EXCEPTION_IF_NULL(event_request_); + MS_EXCEPTION_IF_NULL(resp_buf_); + evhttp_send_reply(event_request_, resp_code_, nullptr, resp_buf_); +} + +void HttpMessageHandler::QuickResponse(int code, const std::string &body) { + MS_EXCEPTION_IF_NULL(event_request_); + MS_EXCEPTION_IF_NULL(resp_buf_); + AddRespString(body); + evhttp_send_reply(event_request_, code, nullptr, resp_buf_); +} + +void HttpMessageHandler::SimpleResponse(int code, const HttpHeaders &headers, const std::string &body) { + MS_EXCEPTION_IF_NULL(event_request_); + MS_EXCEPTION_IF_NULL(resp_buf_); + AddRespHeaders(headers); + AddRespString(body); + MS_EXCEPTION_IF_NULL(resp_buf_); + evhttp_send_reply(event_request_, resp_code_, nullptr, resp_buf_); +} + +void HttpMessageHandler::RespError(int nCode, const std::string &message) { + MS_EXCEPTION_IF_NULL(event_request_); + if (message.empty()) { + evhttp_send_error(event_request_, nCode, nullptr); + } else { + evhttp_send_error(event_request_, nCode, message.c_str()); + } +} +} // namespace comm +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/comm/http_message_handler.h b/mindspore/ccsrc/ps/comm/http_message_handler.h new file mode 100644 index 00000000000..ad9bc6a139d --- /dev/null +++ b/mindspore/ccsrc/ps/comm/http_message_handler.h @@ -0,0 +1,107 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_COMM_HTTP_MESSAGE_HANDLER_H_ +#define MINDSPORE_CCSRC_PS_COMM_HTTP_MESSAGE_HANDLER_H_ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include "utils/log_adapter.h" + +namespace mindspore { +namespace ps { +namespace comm { + +typedef std::map> HttpHeaders; + +class HttpMessageHandler { + public: + explicit HttpMessageHandler(struct evhttp_request *req) + : event_request_(req), + event_uri_(nullptr), + path_params_{0}, + head_params_(nullptr), + post_params_{0}, + post_param_parsed_(false), + body_(nullptr), + resp_headers_(nullptr), + resp_buf_(nullptr), + resp_code_(HTTP_OK) {} + + ~HttpMessageHandler(); + + void InitHttpMessage(); + std::string GetRequestUri(); + std::string GetRequestHost(); + std::string GetHeadParam(const std::string &key); + std::string GetPathParam(const std::string &key); + std::string GetPostParam(const std::string &key); + std::string GetPostMsg(); + std::string GetUriPath(); + std::string GetUriQuery(); + + // It will return -1 if no port set + int GetUriPort(); + + // Useless to get from a request url, fragment is only for browser to locate sth. + std::string GetUriFragment(); + + void AddRespHeadParam(const std::string &key, const std::string &val); + void AddRespHeaders(const HttpHeaders &headers); + void AddRespString(const std::string &str); + void SetRespCode(int code); + + // Make sure code and all response body has finished set + void SendResponse(); + void QuickResponse(int code, const std::string &body); + void SimpleResponse(int code, const HttpHeaders &headers, const std::string &body); + + // If message is empty, libevent will use default error code message instead + void RespError(int nCode, const std::string &message); + + private: + struct evhttp_request *event_request_; + const struct evhttp_uri *event_uri_; + struct evkeyvalq path_params_; + struct evkeyvalq *head_params_; + struct evkeyvalq post_params_; + bool post_param_parsed_; + std::unique_ptr body_; + struct evkeyvalq *resp_headers_; + struct evbuffer *resp_buf_; + int resp_code_; + + // Body length should no more than MAX_POST_BODY_LEN, default 64kB + void ParsePostParam(); +}; + +} // namespace comm +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PS_COMM_HTTP_MESSAGE_HANDLER_H_ diff --git a/mindspore/ccsrc/ps/comm/http_server.cc b/mindspore/ccsrc/ps/comm/http_server.cc new file mode 100644 index 00000000000..f99ec69e4af --- /dev/null +++ b/mindspore/ccsrc/ps/comm/http_server.cc @@ -0,0 +1,186 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/comm/http_server.h" +#include "ps/comm/http_message_handler.h" + +#ifdef WIN32 +#include +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace mindspore { +namespace ps { +namespace comm { + +HttpServer::~HttpServer() { + if (event_http_) { + evhttp_free(event_http_); + event_http_ = nullptr; + } + if (event_base_) { + event_base_free(event_base_); + event_base_ = nullptr; + } +} + +bool HttpServer::InitServer() { + if (!CheckIp(server_address_)) { + MS_LOG(EXCEPTION) << "Server address" << server_address_ << " illegal!"; + } + int64_t uAddr = inet_addr(server_address_.c_str()); + if (INADDR_NONE == uAddr) { + MS_LOG(EXCEPTION) << "Server address illegal, inet_addr converting failed!"; + } + if (server_port_ <= 0) { + MS_LOG(EXCEPTION) << "Server port:" << server_port_ << " illegal!"; + } + + event_base_ = event_base_new(); + MS_EXCEPTION_IF_NULL(event_base_); + event_http_ = evhttp_new(event_base_); + MS_EXCEPTION_IF_NULL(event_http_); + int ret = evhttp_bind_socket(event_http_, server_address_.c_str(), server_port_); + if (ret != 0) { + MS_LOG(EXCEPTION) << "Http bind server addr:" << server_address_.c_str() << " port:" << server_port_ << "failed"; + } + is_init_ = true; + return true; +} + +bool HttpServer::CheckIp(const std::string &ip) { + std::regex pattern("((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?).){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"); + std::smatch res; + if (regex_match(ip, res, pattern)) { + return true; + } + return false; +} + +void HttpServer::SetTimeOut(int seconds) { + MS_EXCEPTION_IF_NULL(event_http_); + if (seconds < 0) { + MS_LOG(EXCEPTION) << "The timeout seconds:" << seconds << "is less than 0!"; + } + evhttp_set_timeout(event_http_, seconds); +} + +void HttpServer::SetAllowedMethod(HttpMethodsSet methods) { + MS_EXCEPTION_IF_NULL(event_http_); + evhttp_set_allowed_methods(event_http_, methods); +} + +void HttpServer::SetMaxHeaderSize(size_t num) { + MS_EXCEPTION_IF_NULL(event_http_); + if (num < 0) { + MS_LOG(EXCEPTION) << "The header num:" << num << "is less than 0!"; + } + evhttp_set_max_headers_size(event_http_, num); +} + +void HttpServer::SetMaxBodySize(size_t num) { + MS_EXCEPTION_IF_NULL(event_http_); + if (num < 0) { + MS_LOG(EXCEPTION) << "The max body num:" << num << "is less than 0!"; + } + evhttp_set_max_body_size(event_http_, num); +} + +bool HttpServer::RegisterRoute(const std::string &url, handle_t *function) { + if ((!is_init_) && (!InitServer())) { + MS_LOG(EXCEPTION) << "Init http server failed!"; + } + HandlerFunc func = function; + if (!func) { + return false; + } + + auto TransFunc = [](struct evhttp_request *req, void *arg) { + MS_EXCEPTION_IF_NULL(req); + MS_EXCEPTION_IF_NULL(arg); + HttpMessageHandler httpReq(req); + httpReq.InitHttpMessage(); + handle_t *f = reinterpret_cast(arg); + f(&httpReq); + }; + handle_t **pph = func.target(); + MS_EXCEPTION_IF_NULL(pph); + MS_EXCEPTION_IF_NULL(event_http_); + + // O SUCCESS,-1 ALREADY_EXIST,-2 FAILURE + int ret = evhttp_set_cb(event_http_, url.c_str(), TransFunc, reinterpret_cast(*pph)); + if (ret == 0) { + MS_LOG(INFO) << "Ev http register handle of:" << url.c_str() << " success."; + } else if (ret == -1) { + MS_LOG(WARNING) << "Ev http register handle of:" << url.c_str() << " exist."; + } else { + MS_LOG(ERROR) << "Ev http register handle of:" << url.c_str() << " failed."; + return false; + } + return true; +} + +bool HttpServer::UnRegisterRoute(const std::string &url) { + MS_EXCEPTION_IF_NULL(event_http_); + return (evhttp_del_cb(event_http_, url.c_str()) == 0); +} + +bool HttpServer::Start() { + MS_EXCEPTION_IF_NULL(event_base_); + int ret = event_base_dispatch(event_base_); + if (ret == 0) { + MS_LOG(INFO) << "Event base dispatch success!"; + return true; + } else if (ret == 1) { + MS_LOG(ERROR) << "Event base dispatch failed with no events pending or active!"; + return false; + } else if (ret == -1) { + MS_LOG(ERROR) << "Event base dispatch failed with error occurred!"; + return false; + } else { + MS_LOG(EXCEPTION) << "Event base dispatch with unexpect error code!"; + } +} + +void HttpServer::Stop() { + if (event_http_) { + evhttp_free(event_http_); + event_http_ = nullptr; + } + if (event_base_) { + event_base_free(event_base_); + event_base_ = nullptr; + } +} + +} // namespace comm +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/comm/http_server.h b/mindspore/ccsrc/ps/comm/http_server.h new file mode 100644 index 00000000000..79d1387e9e7 --- /dev/null +++ b/mindspore/ccsrc/ps/comm/http_server.h @@ -0,0 +1,97 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_COMM_HTTP_SERVER_H_ +#define MINDSPORE_CCSRC_PS_COMM_HTTP_SERVER_H_ + +#include "ps/comm/http_message_handler.h" + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace mindspore { +namespace ps { +namespace comm { + +typedef enum eHttpMethod { + HM_GET = 1 << 0, + HM_POST = 1 << 1, + HM_HEAD = 1 << 2, + HM_PUT = 1 << 3, + HM_DELETE = 1 << 4, + HM_OPTIONS = 1 << 5, + HM_TRACE = 1 << 6, + HM_CONNECT = 1 << 7, + HM_PATCH = 1 << 8 +} HttpMethod; + +typedef u_int16_t HttpMethodsSet; + +typedef void(handle_t)(HttpMessageHandler *); + +class HttpServer { + public: + // Server address only support IPV4 now, and should be in format of "x.x.x.x" + explicit HttpServer(const std::string &address, std::int16_t port) + : server_address_(address), server_port_(port), event_base_(nullptr), event_http_(nullptr), is_init_(false) {} + + ~HttpServer(); + + typedef std::function HandlerFunc; + + bool InitServer(); + static bool CheckIp(const std::string &ip); + void SetTimeOut(int seconds = 5); + + // Default allowed methods: GET, POST, HEAD, PUT, DELETE + void SetAllowedMethod(HttpMethodsSet methods); + + // Default to ((((unsigned long long)0xffffffffUL) << 32) | 0xffffffffUL) + void SetMaxHeaderSize(std::size_t num); + + // Default to ((((unsigned long long)0xffffffffUL) << 32) | 0xffffffffUL) + void SetMaxBodySize(std::size_t num); + + // Return: true if success, false if failed, check log to find failure reason + bool RegisterRoute(const std::string &url, handle_t *func); + bool UnRegisterRoute(const std::string &url); + + bool Start(); + void Stop(); + + private: + std::string server_address_; + std::int16_t server_port_; + struct event_base *event_base_; + struct evhttp *event_http_; + bool is_init_; +}; + +} // namespace comm +} // namespace ps +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PS_COMM_HTTP_SERVER_H_ diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt index 77b095ba848..05092db3c2f 100644 --- a/tests/ut/cpp/CMakeLists.txt +++ b/tests/ut/cpp/CMakeLists.txt @@ -55,6 +55,7 @@ if(ENABLE_MINDDATA) ./transform/*.cc ./utils/*.cc ./vm/*.cc + ./ps/*.cc ) if(NOT ENABLE_PYTHON) @@ -128,6 +129,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "../../../mindspore/ccsrc/backend/session/kernel_build_client.cc" "../../../mindspore/ccsrc/transform/graph_ir/*.cc" "../../../mindspore/ccsrc/transform/graph_ir/op_declare/*.cc" + "../../../mindspore/ccsrc/ps/*.cc" ) list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc") diff --git a/tests/ut/cpp/ps/comm/http_server_test.cc b/tests/ut/cpp/ps/comm/http_server_test.cc new file mode 100644 index 00000000000..514398cf825 --- /dev/null +++ b/tests/ut/cpp/ps/comm/http_server_test.cc @@ -0,0 +1,125 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/comm/http_server.h" +#include "common/common_test.h" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mindspore { +namespace ps { +namespace comm { + +class TestHttpServer : public UT::Common { + public: + TestHttpServer() {} + + static void testGetHandler(HttpMessageHandler *resp) { + std::string host = resp->GetRequestHost(); + EXPECT_STREQ(host.c_str(), "127.0.0.1"); + + std::string path_param = resp->GetPathParam("key1"); + std::string header_param = resp->GetHeadParam("headerKey"); + std::string post_param = resp->GetPostParam("postKey"); + std::string post_message = resp->GetPostMsg(); + EXPECT_STREQ(path_param.c_str(), "value1"); + EXPECT_STREQ(header_param.c_str(), "headerValue"); + EXPECT_STREQ(post_param.c_str(), "postValue"); + EXPECT_STREQ(post_message.c_str(), "postKey=postValue"); + + const std::string rKey("headKey"); + const std::string rVal("headValue"); + const std::string rBody("post request success!\n"); + resp->AddRespHeadParam(rKey, rVal); + resp->AddRespString(rBody); + + resp->SetRespCode(200); + resp->SendResponse(); + } + + void SetUp() override { + server_ = new HttpServer("0.0.0.0", 9999); + server_->RegisterRoute("/httpget", [](HttpMessageHandler *resp) { + EXPECT_STREQ(resp->GetPathParam("key1").c_str(), "value1"); + EXPECT_STREQ(resp->GetUriQuery().c_str(), "key1=value1"); + EXPECT_STREQ(resp->GetRequestUri().c_str(), "/httpget?key1=value1"); + EXPECT_STREQ(resp->GetUriPath().c_str(), "/httpget"); + resp->QuickResponse(200, "get request success!\n"); + }); + server_->RegisterRoute("/handler", TestHttpServer::testGetHandler); + std::unique_ptr http_server_thread_(nullptr); + http_server_thread_.reset(new std::thread([&]() { server_->Start(); })); + http_server_thread_->detach(); + } + + void TearDown() override { server_->Stop(); } + + private: + HttpServer *server_; +}; + +TEST_F(TestHttpServer, httpGetQequest) { + char buffer[100]; + FILE *file; + std::string cmd = "curl -X GET http://127.0.0.1:9999/httpget?key1=value1"; + std::string result; + const char *sysCommand = cmd.data(); + if ((file = popen(sysCommand, "r")) == nullptr) { + return; + } + while (fgets(buffer, sizeof(buffer) - 1, file) != nullptr) { + result += buffer; + } + EXPECT_STREQ("get request success!\n", result.c_str()); + pclose(file); +} + +TEST_F(TestHttpServer, messageHandler) { + char buffer[100]; + FILE *file; + std::string cmd = + R"(curl -X POST -d 'postKey=postValue' -i -H "Accept: application/json" -H "headerKey: headerValue" http://127.0.0.1:9999/handler?key1=value1)"; + std::string result; + const char *sysCommand = cmd.data(); + if ((file = popen(sysCommand, "r")) == nullptr) { + return; + } + while (fgets(buffer, sizeof(buffer) - 1, file) != nullptr) { + result += buffer; + } + EXPECT_STREQ("post request success!\n", result.substr(result.find("post")).c_str()); + pclose(file); +} + +TEST_F(TestHttpServer, portException) { + HttpServer *server_exception = new HttpServer("0.0.0.0", -1); + ASSERT_THROW(server_exception->RegisterRoute("/handler", TestHttpServer::testGetHandler), std::exception); +} + +TEST_F(TestHttpServer, addressException) { + HttpServer *server_exception = new HttpServer("12344.0.0.0", 9998); + ASSERT_THROW(server_exception->RegisterRoute("/handler", TestHttpServer::testGetHandler), std::exception); +} + +} // namespace comm +} // namespace ps +} // namespace mindspore