Add route table service and proxy impl.

This commit is contained in:
ZPaC 2022-02-23 14:34:51 +08:00
parent b49da6cf95
commit ab5d35fab7
12 changed files with 227 additions and 24 deletions

View File

@ -15,16 +15,74 @@
*/
#include <string>
#include <vector>
#include "distributed/cluster/actor_route_table_proxy.h"
namespace mindspore {
namespace distributed {
namespace cluster {
bool ActorRouteTableProxy::RegisterRoute(const std::string &actor_id, const ActorAddress &actor_addr) { return true; }
bool ActorRouteTableProxy::RegisterRoute(const std::string &actor_id, const ActorAddress &actor_addr) {
MS_EXCEPTION_IF_NULL(node_);
std::shared_ptr<std::vector<unsigned char>> output = nullptr;
if (!node_->SendToScheduler(actor_addr.SerializeAsString().data(), actor_addr.SerializeAsString().size(),
NodeCommand::REGISTER_ACTOR_ROUTE, &output)) {
MS_LOG(EXCEPTION) << "Failed to send register route request to scheduler.";
}
bool ActorRouteTableProxy::DeleteRoute(const std::string &actor_id) { return true; }
GeneralResponseMsg register_route_rsp_msg;
MS_EXCEPTION_IF_NULL(output);
(void)register_route_rsp_msg.ParseFromArray(output->data(), SizeToInt(output->size()));
if (!register_route_rsp_msg.is_success()) {
MS_LOG(ERROR) << "Register route for actor " << actor_id << " failed. " << register_route_rsp_msg.error();
return false;
}
return true;
}
ActorAddress ActorRouteTableProxy::LookupRoute(const std::string &actor_id) const { return {}; }
bool ActorRouteTableProxy::DeleteRoute(const std::string &actor_id) {
MS_EXCEPTION_IF_NULL(node_);
std::shared_ptr<std::vector<unsigned char>> output = nullptr;
if (!node_->SendToScheduler(actor_id.data(), actor_id.size(), NodeCommand::DELETE_ACTOR_ROUTE, &output)) {
MS_LOG(EXCEPTION) << "Failed to send delete route request to scheduler.";
}
GeneralResponseMsg delete_route_rsp_msg;
MS_EXCEPTION_IF_NULL(output);
(void)delete_route_rsp_msg.ParseFromArray(output->data(), SizeToInt(output->size()));
if (!delete_route_rsp_msg.is_success()) {
MS_LOG(ERROR) << "Delete route for actor " << actor_id << " failed. " << delete_route_rsp_msg.error();
return false;
}
return true;
}
// Looks up the address of 'actor_id' in the route table stored in the scheduler.
// The request is repeated every kLookupInterval milliseconds until the scheduler answers with a non-empty actor id or
// the 'lookup_timeout_' window has elapsed.
// Returns the address sent by the scheduler. If the window elapses first, the returned address has an empty actor id,
// so callers must check it. MS_EXCEPTION_IF_NULL / MS_LOG(EXCEPTION) are raised when 'node_' is null, when
// SendToScheduler fails, or when no reply payload was produced.
ActorAddress ActorRouteTableProxy::LookupRoute(const std::string &actor_id) const {
  MS_EXCEPTION_IF_NULL(node_);
  // Whether this lookup operation is successful.
  bool lookup_success = false;
  // The deadline is measured on a monotonic clock so that wall-clock adjustments cannot shorten or stretch the window.
  const auto deadline = std::chrono::steady_clock::now() + lookup_timeout_;
  std::shared_ptr<std::vector<unsigned char>> output = nullptr;
  ActorAddress lookup_route_rsp_msg;

  do {
    if (!node_->SendToScheduler(actor_id.data(), actor_id.size(), NodeCommand::LOOKUP_ACTOR_ROUTE, &output)) {
      MS_LOG(EXCEPTION) << "Failed to send lookup route request to scheduler.";
    }
    MS_EXCEPTION_IF_NULL(output);
    if (!lookup_route_rsp_msg.ParseFromArray(output->data(), SizeToInt(output->size()))) {
      // Drop whatever was partially parsed so that a corrupt reply is handled like "not registered yet".
      MS_LOG(WARNING) << "Failed to parse the lookup route response for actor " << actor_id << ".";
      lookup_route_rsp_msg.Clear();
    }
    // An actor route could not be registered yet because another process could be launched slow.
    // If the response actor id is empty, this means the address is not registered yet.
    if (lookup_route_rsp_msg.actor_id().empty()) {
      MS_LOG(DEBUG) << "Actor route for actor " << actor_id << " is not registered yet, please try later.";
      std::this_thread::sleep_for(std::chrono::milliseconds(kLookupInterval));
    } else {
      lookup_success = true;
    }
  } while (!lookup_success && std::chrono::steady_clock::now() <= deadline);

  if (!lookup_success) {
    MS_LOG(WARNING) << "Lookup route for actor " << actor_id << " timed out.";
  }
  return lookup_route_rsp_msg;
}
} // namespace cluster
} // namespace distributed
} // namespace mindspore

View File

@ -19,19 +19,27 @@
#include <string>
#include <memory>
#include <chrono>
#include "proto/comm.pb.h"
#include "ps/core/node.h"
#include "ps/core/abstract_node.h"
#include "distributed/constants.h"
namespace mindspore {
namespace distributed {
namespace cluster {
using ps::core::ActorAddress;
using ps::core::GeneralResponseMsg;
using ps::core::NodeCommand;
// The time in milliseconds between two lookup operations.
constexpr auto kLookupInterval = 100;
// Actor route table proxy for nodes like workers and server. This class helps update actor route table in scheduler
// across the network.
class ActorRouteTableProxy {
public:
explicit ActorRouteTableProxy(const std::shared_ptr<ps::core::Node> &node) : node_(node) {}
explicit ActorRouteTableProxy(const std::shared_ptr<ps::core::AbstractNode> &node, uint32_t lookup_timout)
: node_(node), lookup_timeout_(std::chrono::milliseconds(lookup_timout)) {}
~ActorRouteTableProxy() = default;
// Register actor address to the route table stored in scheduler.
@ -45,7 +53,10 @@ class ActorRouteTableProxy {
private:
// The node variable helps proxy to communicate with scheduler, e.g., SendMessage.
std::shared_ptr<ps::core::Node> node_;
std::shared_ptr<ps::core::AbstractNode> node_;
// The timeout window for the route lookup operation. It is needed because processes are launched at different times,
// so a route may not be registered yet when it is looked up.
std::chrono::milliseconds lookup_timeout_;
};
} // namespace cluster
} // namespace distributed

View File

@ -14,6 +14,8 @@
* limitations under the License.
*/
#include <mutex>
#include <shared_mutex>
#include "distributed/cluster/actor_route_table_service.h"
namespace mindspore {
@ -23,12 +25,36 @@ bool ActorRouteTableService::Initialize() { return true; }
// Stores 'actor_addr' as the route of 'actor_id' while holding the exclusive lock on 'mtx_'.
// Returns false when 'error' is null, or when 'actor_id' is already registered (the reason is written to '*error').
bool ActorRouteTableService::RegisterRoute(const std::string &actor_id, const ActorAddress &actor_addr,
                                           std::string *error) {
  MS_ERROR_IF_NULL_W_RET_VAL(error, false);
  std::unique_lock lock(mtx_);
  // emplace leaves an existing entry untouched and reports through '.second' whether a new one was added.
  const bool newly_added = actor_addresses_.emplace(actor_id, actor_addr).second;
  if (newly_added) {
    return true;
  }
  *error = "The address of actor id " + actor_id + " already exists.";
  return false;
}
bool ActorRouteTableService::DeleteRoute(const std::string &actor_id, std::string *error) { return true; }
// Removes the route of 'actor_id' while holding the exclusive lock on 'mtx_'.
// Returns false when 'error' is null, or when 'actor_id' is not registered (the reason is written to '*error').
bool ActorRouteTableService::DeleteRoute(const std::string &actor_id, std::string *error) {
  MS_ERROR_IF_NULL_W_RET_VAL(error, false);
  std::unique_lock lock(mtx_);
  // erase reports how many entries were removed; zero means the actor id was never registered.
  const auto erased_num = actor_addresses_.erase(actor_id);
  if (erased_num != 0) {
    return true;
  }
  *error = "The address of actor id " + actor_id + " does not exist.";
  return false;
}
ActorAddress ActorRouteTableService::LookupRoute(const std::string &actor_id, std::string *error) { return {}; }
// Returns the address registered for 'actor_id' while holding the shared (reader) lock on 'mtx_'.
// Returns a default-constructed address when 'error' is null, or when 'actor_id' is not registered (the reason is
// written to '*error').
ActorAddress ActorRouteTableService::LookupRoute(const std::string &actor_id, std::string *error) {
  MS_ERROR_IF_NULL_W_RET_VAL(error, {});
  std::shared_lock lock(mtx_);
  // Only read-only container operations may run under the shared lock: operator[] is a modifying member, so the
  // entry is located with find instead.
  const auto iter = actor_addresses_.find(actor_id);
  if (iter == actor_addresses_.end()) {
    *error = "The address of actor id " + actor_id + " does not exist.";
    return {};
  }
  return iter->second;
}
} // namespace cluster
} // namespace distributed
} // namespace mindspore

View File

@ -18,10 +18,12 @@
#define MINDSPORE_CCSRC_DISTRIBUTED_CLUSTER_ACTOR_ROUTE_TABLE_SERVICE_H_
#include <map>
#include <mutex>
#include <string>
#include <memory>
#include <shared_mutex>
#include "proto/comm.pb.h"
#include "utils/log_adapter.h"
#include "distributed/constants.h"
namespace mindspore {

View File

@ -19,6 +19,7 @@
#include <set>
#include <map>
#include <chrono>
#include <string>
namespace mindspore {
@ -46,6 +47,10 @@ constexpr int MAX_HOSTNAME_LEN = 1024;
const uint16_t kDefaultSchedPort = 6667;
const uint16_t kMaxPort = 65535;
constexpr uint32_t kDefaultFinishTimeout = 30;
// This macro expands to the current timestamp in milliseconds: the time since the epoch of
// std::chrono::system_clock (the wall clock), converted to std::chrono::milliseconds.
#define CURRENT_TIMESTAMP_MILLI \
std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch())
} // namespace distributed
} // namespace mindspore
#endif // MINDSPORE_CCSRC_DISTRIBUTED_CONSTANTS_H_

View File

@ -149,10 +149,11 @@ void GraphSplitter::SplitGraph(const std::vector<SplitGraphSegment> &segments,
MS_LOG(EXCEPTION) << "This segment is empty.";
return;
}
if (node_labels_[nodes[0]] != segment.label) {
MS_LOG(EXCEPTION) << "Node label " << node_labels_[nodes[0]].to_string() << " is not the same as segment label "
<< segment.label.to_string();
return;
auto segment_first_node = nodes[0];
if (node_labels_[segment_first_node] != segment.label) {
MS_LOG(EXCEPTION) << "Node label " << node_labels_[segment_first_node].to_string()
<< " is not the same as segment label " << segment.label.to_string();
}
// Add Depend between in-degree and out-degree of this segment because the execution order should be kept
@ -199,7 +200,6 @@ void GraphSplitter::SplitGraph(const std::vector<SplitGraphSegment> &segments,
if (send_label == recv_label) {
MS_LOG(EXCEPTION) << "The Send and Recv must have different label. But got Send: " << send_label.to_string()
<< ", Recv: " << recv_label.to_string();
return;
}
if (recv_label == this_process_label_) {
@ -227,7 +227,6 @@ OperatorLabel GraphSplitter::GetSplitLabel(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Only CNode has distributed split label.";
return default_label_;
}
CNodePtr cnode = node->cast<CNodePtr>();
auto prim_node = cnode->input(0);
@ -301,11 +300,17 @@ CNodePtr GraphSplitter::GenerateRecvNode(const AnfNodePtr &input, const AnfNodeP
std::vector<AnfNodePtr> recv_inputs = {NewValueNode(std::make_shared<Primitive>(kRpcRecvOpName))};
if (IsPrimitiveCNode(input, prim::kPrimUpdateState)) {
ValuePtr monad_value = nullptr;
if (HasAbstractUMonad(input)) {
auto monad_input = NewValueNode(kUMonad);
monad_input->set_abstract(kUMonad->ToAbstract());
recv_inputs.push_back(monad_input);
monad_value = kUMonad;
} else if (HasAbstractIOMonad(input)) {
monad_value = kIOMonad;
} else {
MS_LOG(EXCEPTION) << "The input is PrimUpdateState must have monad abstract.";
}
auto monad_input = NewValueNode(monad_value);
monad_input->set_abstract(monad_value->ToAbstract());
recv_inputs.push_back(monad_input);
} else {
auto mock_value = GenerateMockValueNode(true, input);
MS_EXCEPTION_IF_NULL(mock_value);
@ -332,7 +337,6 @@ void GraphSplitter::SetSendNodeAttr(const AnfNodePtr &send_node, const AnfNodePt
std::string to_node_name = send_to_node->fullname_with_scope();
if (node_labels_.count(send_to_node) == 0) {
MS_LOG(EXCEPTION) << "Send to node " << to_node_name << " has no operator label.";
return;
}
// These attributes are the inter-process edge information.
@ -358,7 +362,6 @@ void GraphSplitter::SetRecvNodeAttr(const AnfNodePtr &recv_node, const AnfNodePt
std::string to_node_name = recv_to_node->fullname_with_scope();
if (node_labels_.count(recv_from_node) == 0) {
MS_LOG(EXCEPTION) << "Recv from node " << from_node_name << " has no operator label.";
return;
}
// These attributes are the inter-process edge information.
@ -431,7 +434,6 @@ bool GraphSplitter::IsNodesWithSameLabel(const AnfNodePtr &node1, const AnfNodeP
if (node_labels_.count(node1) == 0 || node_labels_.count(node2) == 0) {
MS_LOG(EXCEPTION) << "Either 'node1': " << node1->fullname_with_scope()
<< " or 'node2': " << node2->fullname_with_scope() << " is not marked with split label.";
return false;
}
return node_labels_[node1] == node_labels_[node2];
}

View File

@ -287,6 +287,39 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
return Wait(request_id, timeout);
}
bool AbstractNode::SendToScheduler(const void *message, size_t len, NodeCommand node_cmd, VectorPtr *output,
const uint32_t &timeout) {
MS_EXCEPTION_IF_NULL(message);
uint32_t expected_reponse_num = 1;
uint64_t request_id = AddMessageTrack(expected_reponse_num);
auto message_meta = std::make_shared<MessageMeta>();
MS_EXCEPTION_IF_NULL(message_meta);
message_meta->set_cmd(node_cmd);
message_meta->set_request_id(request_id);
MS_EXCEPTION_IF_NULL(client_to_scheduler_);
if (!client_to_scheduler_->SendMessage(message_meta, Protos::RAW, message, len)) {
MS_LOG(WARNING) << "Failed to send message" << node_cmd << "to scheduler.";
}
bool ret = Wait(request_id, timeout);
if (!ret) {
MS_LOG(ERROR) << "Sending message " << node_cmd << " to scheduler timeout.";
return ret;
}
// Assign the response value from scheduler.
if (output != nullptr) {
if (received_scheduler_messages_.count(request_id) == 0) {
MS_LOG(ERROR) << "The response message of " << node_cmd << " is not received yet.";
return false;
}
*output = received_scheduler_messages_[request_id];
}
return ret;
}
uint64_t AbstractNode::CollectiveSendAsync(const NodeRole &node_role, const uint32_t &rank_id, const void *data,
size_t size) {
MS_EXCEPTION_IF_NULL(data);
@ -615,6 +648,25 @@ void AbstractNode::ProcessFetchServersResp(const std::shared_ptr<MessageMeta> &m
}
}
void AbstractNode::ProcessActorRouteServiceResp(const std::shared_ptr<MessageMeta> &meta, const void *data,
size_t size) {
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
std::lock_guard<std::mutex> lock(receive_messages_mutex_);
const uint64_t request_id = meta->request_id();
VectorPtr received_data = std::make_shared<std::vector<unsigned char>>(size, 0);
if (size > 0) {
size_t dest_size = size;
size_t src_size = size;
auto ret = memcpy_s(received_data.get()->data(), dest_size, data, src_size);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
}
received_scheduler_messages_[request_id] = received_data;
}
void AbstractNode::ProcessSendMetadata(const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data,
size_t size) {
@ -1088,6 +1140,13 @@ void AbstractNode::InitCommandHandler() {
handlers_[NodeCommand::SCALE_OUT_DONE] = nullptr;
handlers_[NodeCommand::SCALE_IN_DONE] = nullptr;
handlers_[NodeCommand::SEND_EVENT] = nullptr;
RegisterActorRouteTableRspHandler();
}
// Registers ProcessActorRouteServiceResp as the handler of the three actor route table commands
// (register, delete and lookup).
void AbstractNode::RegisterActorRouteTableRspHandler() {
  const auto route_rsp_handler = &AbstractNode::ProcessActorRouteServiceResp;
  handlers_[NodeCommand::REGISTER_ACTOR_ROUTE] = route_rsp_handler;
  handlers_[NodeCommand::DELETE_ACTOR_ROUTE] = route_rsp_handler;
  handlers_[NodeCommand::LOOKUP_ACTOR_ROUTE] = route_rsp_handler;
}
void AbstractNode::InitServerHandler() {

View File

@ -98,6 +98,10 @@ class AbstractNode : public Node {
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &msgs,
int command, std::vector<VectorPtr> *output = nullptr, const uint32_t &timeout = kCommTimeoutInSeconds);
// The interface that sends sync message to the scheduler.
// 'message'/'len' describe the raw bytes to send and 'command' is the command attached to the request.
// The call waits for the scheduler's response for at most 'timeout' (default: kCommTimeoutInSeconds).
// When 'output' is not null, the response payload is assigned to it.
// Returns false if the wait times out or no response payload is available for the request.
bool SendToScheduler(const void *message, size_t len, NodeCommand command, VectorPtr *output = nullptr,
const uint32_t &timeout = kCommTimeoutInSeconds);
uint64_t CollectiveSendAsync(const NodeRole &node_role, const uint32_t &rank_id, const void *data, size_t size);
std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const NodeRole &node_role, const uint32_t &rank_id,
VectorPtr *output);
@ -150,6 +154,9 @@ class AbstractNode : public Node {
void ProcessHeartbeatResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
void ProcessFetchServersResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
// Process the response messages about actor route table service.
// The 'size' bytes at 'data' are copied and stored under the request id carried by 'meta'.
void ProcessActorRouteServiceResp(const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size);
void ProcessSendMetadata(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
const Protos &protos, const void *data, size_t size);
void ProcessFinish(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
@ -201,6 +208,7 @@ class AbstractNode : public Node {
uint64_t NextExpectedRankRequestId(const uint32_t &rank_id);
uint64_t NextActualRankRequestId(const uint32_t &rank_id);
void InitCommandHandler();
void RegisterActorRouteTableRspHandler();
void InitServerHandler();
// when initializing the node, should initializing the node info.

View File

@ -103,9 +103,13 @@ bool Node::Wait(uint64_t request_id, const uint32_t &timeout) {
tracker_lock.unlock();
std::unique_lock<std::mutex> msgs_lock(receive_messages_mutex_);
// The messages should be already copied before message_tracker_cond_ is notified.
if (receive_messages_.count(request_id) != 0) {
(void)receive_messages_.erase(request_id);
}
if (received_scheduler_messages_.count(request_id) != 0) {
(void)received_scheduler_messages_.erase(request_id);
}
msgs_lock.unlock();
return res;
}

View File

@ -139,6 +139,9 @@ class Node {
std::unordered_map<uint64_t, std::unordered_map<uint32_t, VectorPtr>> workder_receive_messages_;
std::map<std::pair<uint32_t, uint64_t>, bool> receive_messages_done_;
std::mutex receive_messages_mutex_;
// Messages received from the scheduler. The key is the request_id and the value is the received response data.
std::unordered_map<uint64_t, VectorPtr> received_scheduler_messages_;
};
} // namespace core
} // namespace ps

View File

@ -488,6 +488,9 @@ void SchedulerNode::ProcessSendEvent(const std::shared_ptr<TcpServer> &server,
void SchedulerNode::ProcessRegisterActorRoute(const std::shared_ptr<TcpServer> &server,
const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
MS_ERROR_IF_NULL_WO_RET_VAL(server);
MS_ERROR_IF_NULL_WO_RET_VAL(conn);
MS_ERROR_IF_NULL_WO_RET_VAL(meta);
MS_ERROR_IF_NULL_WO_RET_VAL(data);
MS_ERROR_IF_NULL_WO_RET_VAL(actor_route_table_service_);
ActorAddress actor_address;
@ -501,11 +504,35 @@ void SchedulerNode::ProcessRegisterActorRoute(const std::shared_ptr<TcpServer> &
void SchedulerNode::ProcessDeleteActorRoute(const std::shared_ptr<TcpServer> &server,
const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {}
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
MS_ERROR_IF_NULL_WO_RET_VAL(server);
MS_ERROR_IF_NULL_WO_RET_VAL(conn);
MS_ERROR_IF_NULL_WO_RET_VAL(meta);
MS_ERROR_IF_NULL_WO_RET_VAL(data);
MS_ERROR_IF_NULL_WO_RET_VAL(actor_route_table_service_);
std::string actor_id(static_cast<const char *>(data), size);
std::string error = "";
bool ret = actor_route_table_service_->DeleteRoute(actor_id, &error);
GeneralResponse(server, conn, meta, ret, error);
}
void SchedulerNode::ProcessLookupActorRoute(const std::shared_ptr<TcpServer> &server,
const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {}
const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size) {
MS_ERROR_IF_NULL_WO_RET_VAL(server);
MS_ERROR_IF_NULL_WO_RET_VAL(conn);
MS_ERROR_IF_NULL_WO_RET_VAL(meta);
MS_ERROR_IF_NULL_WO_RET_VAL(data);
MS_ERROR_IF_NULL_WO_RET_VAL(actor_route_table_service_);
std::string actor_id(static_cast<const char *>(data), size);
std::string error = "";
ActorAddress address = actor_route_table_service_->LookupRoute(actor_id, &error);
if (!server->SendMessage(conn, meta, Protos::PROTOBUF, address.SerializeAsString().data(), address.ByteSizeLong())) {
MS_LOG(ERROR) << "Scheduler failed to respond message for lookup route.";
}
}
bool SchedulerNode::SendPrepareBuildingNetwork(const std::unordered_map<std::string, NodeInfo> &node_infos) {
uint64_t request_id = AddMessageTrack(node_infos.size());

View File

@ -48,7 +48,6 @@ void RpcNodeScheduler::Link(const ActorSetPtr &) {
<< " is invalid. send_dst_ranks: " << send_dst_ranks << ", send_dst_roles: " << send_dst_roles
<< ", send_src_node_name: " << send_src_node_name
<< ", send_dst_node_name: " << send_dst_node_name;
return;
}
send_actor->SetRouteInfo(send_dst_ranks[0], send_dst_roles[0], send_src_node_name, send_dst_node_name);
}
@ -66,7 +65,6 @@ void RpcNodeScheduler::Link(const ActorSetPtr &) {
<< " is invalid. recv_src_ranks: " << recv_src_ranks << ", recv_src_roles: " << recv_src_roles
<< ", recv_src_node_name: " << recv_src_node_name
<< ", recv_dst_node_name: " << recv_dst_node_name;
return;
}
recv_actor->SetRouteInfo(recv_src_ranks[0], recv_src_roles[0], recv_src_node_name, recv_dst_node_name);
}