Adapt for ps test case

ZPaC 2022-04-29 18:43:53 +08:00
parent c29d6bb764
commit 78bd48fc57
12 changed files with 94 additions and 26 deletions

View File

@@ -25,6 +25,11 @@
namespace mindspore {
namespace distributed {
namespace rpc {
// Print the error message only once every 1000 events and sleep for 50 ms, in case the log file grows too large.
static size_t kPrintCount = 0;
size_t kPrintCountInterval = 1000;
size_t kPrintTimeInterval = 50000;
// Handle socket events like read/write.
void SocketEventHandler(int fd, uint32_t events, void *context) {
Connection *conn = reinterpret_cast<Connection *>(context);
@@ -61,9 +66,13 @@ void SocketEventHandler(int fd, uint32_t events, void *context) {
(conn->recv_message_type != ParseType::kHttpReq && conn->recv_message_type != ParseType::kHttpRsp &&
(events & (uint32_t)(EPOLLHUP | EPOLLRDHUP | EPOLLERR)))) {
if (conn->recv_message_type == ParseType::kTcpMsg) {
MS_LOG(INFO) << "Event value fd: " << fd << ", events: " << events << ", state: " << conn->state
<< ", errcode: " << conn->error_code << ", errno: " << errno << ", to: " << conn->destination.c_str()
<< ", type:" << conn->recv_message_type << ", remote: " << conn->is_remote;
if (kPrintCount++ % kPrintCountInterval == 0) {
MS_LOG(INFO) << "Event value fd: " << fd << ", events: " << events << ", state: " << conn->state
<< ", errcode: " << conn->error_code << ", errno: " << errno
<< ", to: " << conn->destination.c_str() << ", type:" << conn->recv_message_type
<< ", remote: " << conn->is_remote;
usleep(kPrintTimeInterval);
}
}
conn->state = ConnectionState::kDisconnecting;
if (conn->event_callback != nullptr) {
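The new counters turn the unconditional per-event log into a throttled one: only one out of every kPrintCountInterval TCP disconnect events is logged, and the handler then backs off with usleep so the log file cannot grow unbounded. A minimal sketch of the same throttling pattern, written in Python with illustrative names rather than the MindSpore symbols:

import time

PRINT_COUNT_INTERVAL = 1000  # log only one event out of every 1000
PRINT_TIME_INTERVAL = 0.05   # back off for 50 ms after each logged event

_print_count = 0

def on_disconnect_event(fd, events):
    """Log a disconnect event, but only once every PRINT_COUNT_INTERVAL calls."""
    global _print_count
    if _print_count % PRINT_COUNT_INTERVAL == 0:
        print(f"Event value fd: {fd}, events: {events}")
        time.sleep(PRINT_TIME_INTERVAL)
    _print_count += 1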

View File

@@ -982,7 +982,7 @@ void GraphSplitter::SplitGraph(const std::vector<SplitGraphSegment> &segments,
InOutDegreeList in_out_degree_list = GenerateInOutDegreeList(segments, comm_edges);
if (in_out_degree_list.empty()) {
MS_LOG(WARNING) << "After splitting, this process has no graph on it. So optimize out the whole graph.";
auto return_value_node = CreateFakeValueNode(false);
auto return_value_node = CreateReplacedOutputNode(func_graph_, func_graph_->output());
(void)func_graph_->manager()->Replace(func_graph_->output(), return_value_node);
return;
}
@@ -995,6 +995,13 @@ void GraphSplitter::SplitGraph(const std::vector<SplitGraphSegment> &segments,
}
void GraphSplitter::SplitGraph(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs) {
if (fused_inter_process_op_pairs.empty()) {
MS_LOG(WARNING) << "After splitting, this process has no graph on it. So optimize out the whole graph.";
auto return_value_node = CreateReplacedOutputNode(func_graph_, func_graph_->output());
(void)func_graph_->manager()->Replace(func_graph_->output(), return_value_node);
return;
}
// Step 1: Replace origin nodes with recv nodes.
ReplaceOriginNodesWithRecv(fused_inter_process_op_pairs);

View File

@@ -557,7 +557,8 @@ PYBIND11_MODULE(_c_expression, m) {
.def("set_participation_time_level", &PSContext::set_participation_time_level, "Set participation time level.")
.def("participation_time_level", &PSContext::participation_time_level, "Get participation time level.")
.def("set_continuous_failure_times", &PSContext::set_continuous_failure_times, "Set continuous failure times")
.def("continuous_failure_times", &PSContext::continuous_failure_times, "Get continuous failure times.");
.def("continuous_failure_times", &PSContext::continuous_failure_times, "Get continuous failure times.")
.def("enable_distributed_mindrt", &PSContext::enable_distributed_mindrt, "Whether distributed MindRT is enabled.");
(void)m.def("_encrypt", &mindspore::pipeline::PyEncrypt, "Encrypt the data.");
(void)m.def("_decrypt", &mindspore::pipeline::PyDecrypt, "Decrypt the data.");
(void)m.def("_is_cipher_file", &mindspore::pipeline::PyIsCipherFile, "Determine whether the file is encrypted");

View File

@@ -735,18 +735,6 @@ std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::stri
std::string backend = MsContext::GetInstance()->backend_policy();
#if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
const std::string &server_mode = ps::PSContext::instance()->server_mode();
if ((server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid) &&
ps::PSContext::instance()->is_server()) {
return ServerPipeline(resource);
}
if (ps::PSContext::instance()->is_server()) {
resource->SetResult(kBackend, compile::CreateBackend());
return PServerPipeline(resource);
}
if (ps::PSContext::instance()->is_scheduler()) {
return PSchedulerPipeline(resource);
}
if (distributed::cluster::ClusterContext::instance()->initialized()) {
auto node = distributed::cluster::ClusterContext::instance()->node();
MS_EXCEPTION_IF_NULL(node);
@@ -758,6 +746,19 @@ std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::stri
} else if (cluster_ctx->node_role() == distributed::kEnvRoleOfScheduler) {
return PSchedulerPipeline(resource);
}
} else {
const std::string &server_mode = ps::PSContext::instance()->server_mode();
if ((server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid) &&
ps::PSContext::instance()->is_server()) {
return ServerPipeline(resource);
}
if (ps::PSContext::instance()->is_server()) {
resource->SetResult(kBackend, compile::CreateBackend());
return PServerPipeline(resource);
}
if (ps::PSContext::instance()->is_scheduler()) {
return PSchedulerPipeline(resource);
}
}
#endif
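The net effect of moving the legacy checks into the else branch is a priority change: when the distributed cluster context is initialized (the MindRT path), the pipeline is chosen from the cluster node role, and the old PS/FL server and scheduler checks only run otherwise. A rough Python sketch of the resulting selection order, with placeholder strings instead of the real constants and with the non-scheduler cluster branches elided (they sit in the part of the hunk not shown above):

def choose_pipeline(cluster_initialized, node_role, server_mode, is_server, is_scheduler):
    """Sketch of the reordered branch structure in GetPipeline (placeholder names)."""
    if cluster_initialized:
        # Distributed MindRT: the cluster node role decides the pipeline.
        if node_role == "SCHEDULER":
            return "PSchedulerPipeline"
        return "DefaultPipeline"  # other roles handled in the elided branches
    # Legacy PS/FL checks now run only when the cluster context is not initialized.
    if server_mode in ("FEDERATED_LEARNING", "HYBRID_TRAINING") and is_server:
        return "ServerPipeline"
    if is_server:
        return "PServerPipeline"
    if is_scheduler:
        return "PSchedulerPipeline"
    return "DefaultPipeline"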

View File

@@ -583,5 +583,9 @@ void PSContext::set_continuous_failure_times(uint32_t continuous_failure_times)
}
uint32_t PSContext::continuous_failure_times() { return continuous_failure_times_; }
bool PSContext::enable_distributed_mindrt() const {
return distributed::cluster::ClusterContext::instance()->initialized();
}
} // namespace ps
} // namespace mindspore

View File

@@ -253,6 +253,9 @@ class BACKEND_EXPORT PSContext {
void set_continuous_failure_times(uint32_t continuous_failure_times);
uint32_t continuous_failure_times();
// Whether distributed MindRT is enabled.
bool enable_distributed_mindrt() const;
private:
PSContext()
: ps_enabled_(false),

View File

@@ -25,6 +25,12 @@
namespace mindspore {
namespace runtime {
RecvActor::~RecvActor() {
if (server_) {
server_->Finalize();
}
}
void RecvActor::SetOpcontext(OpContext<DeviceTensor> *const op_context) {
std::unique_lock<std::mutex> lock(context_mtx_);
MS_EXCEPTION_IF_NULL(op_context);

View File

@@ -36,8 +36,11 @@ class RecvActor : public RpcActor {
const std::set<size_t> &modifiable_ref_output_indexes)
: RpcActor(name, kernel, device_context, memory_manager_aid, debug_aid, recorder_aid, strategy,
modifiable_ref_input_indexes, modifiable_ref_output_indexes, KernelTransformType::kRecvActor),
ip_(""),
port_(0),
server_(nullptr),
is_context_valid_(false) {}
~RecvActor() override = default;
~RecvActor() override;
// Besides set the op context, this method also notify the message handler to 'RunOpInterProcessData'.
void SetOpcontext(OpContext<DeviceTensor> *const op_context) override;

View File

@@ -20,6 +20,13 @@
namespace mindspore {
namespace runtime {
SendActor::~SendActor() {
if (client_) {
client_->Disconnect(server_url_);
client_->Finalize();
}
}
void SendActor::SetRouteInfo(uint32_t, const std::string &, const std::string &send_src_node_name,
const std::string &send_dst_node_name) {
auto peer_actor_id = inter_process_edge_name_;
@@ -38,12 +45,12 @@ bool SendActor::ConnectServer() {
MS_EXCEPTION_IF_NULL(actor_route_table_proxy_);
auto peer_actor_address = actor_route_table_proxy_->LookupRoute(peer_actor_id);
// If route is successfully looked up, peer_actor_address is not empty.
std::string server_url = peer_actor_address.ip() + ":" + std::to_string(peer_actor_address.port());
server_url_ = peer_actor_address.ip() + ":" + std::to_string(peer_actor_address.port());
if (!client_->Connect(server_url)) {
if (!client_->Connect(server_url_)) {
MS_LOG(EXCEPTION) << "Failed to connect to server of actor " << peer_actor_id << ", server_url: " << server_url;
MS_LOG(EXCEPTION) << "Failed to connect to server of actor " << peer_actor_id << ", server_url: " << server_url_;
}
MS_LOG(INFO) << "Successfully connect to server " << server_url << ", inter-process edge name: " << peer_actor_id;
MS_LOG(INFO) << "Successfully connect to server " << server_url_ << ", inter-process edge name: " << peer_actor_id;
peer_actor_urls_[peer_actor_id] = server_url;
peer_actor_urls_[peer_actor_id] = server_url_;
}
return true;
}
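Promoting the connect URL from a local variable to the server_url_ member is what lets the destructor added above disconnect from exactly the endpoint that ConnectServer() reached. A small self-contained Python sketch of the same ownership pattern (the client class is a stand-in, not the MindSpore TCPClient):

class FakeTcpClient:
    """Stand-in client; only the lifecycle calls matter for this sketch."""
    def connect(self, url): print(f"connect {url}")
    def disconnect(self, url): print(f"disconnect {url}")
    def finalize(self): print("finalize")

class SendChannel:
    """Remember the peer URL at connect time so teardown can disconnect from it."""
    def __init__(self):
        self.server_url = None
        self.client = FakeTcpClient()

    def connect(self, ip, port):
        self.server_url = f"{ip}:{port}"  # kept for later teardown
        self.client.connect(self.server_url)

    def close(self):
        if self.server_url is not None:
            self.client.disconnect(self.server_url)
            self.client.finalize()
            self.server_url = None

channel = SendChannel()
channel.connect("127.0.0.1", 8118)
channel.close()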

View File

@@ -33,8 +33,10 @@ class SendActor : public RpcActor {
GraphExecutionStrategy strategy, const std::set<size_t> &modifiable_ref_input_indexes,
const std::set<size_t> &modifiable_ref_output_indexes)
: RpcActor(name, kernel, device_context, memory_manager_aid, debug_aid, recorder_aid, strategy,
modifiable_ref_input_indexes, modifiable_ref_output_indexes, KernelTransformType::kSendActor) {}
~SendActor() override = default;
modifiable_ref_input_indexes, modifiable_ref_output_indexes, KernelTransformType::kSendActor),
server_url_(""),
client_(nullptr) {}
~SendActor() override;
// Set send actor's destination peer info, in another word, send actor's output.
void SetRouteInfo(uint32_t dst_rank, const std::string &dst_role, const std::string &send_src_node_name,
@@ -57,6 +59,9 @@ class SendActor : public RpcActor {
std::vector<std::string> peer_actor_ids_;
mindspore::HashMap<std::string, std::string> peer_actor_urls_;
// The url of the peer recv actor's tcp server.
std::string server_url_;
std::unique_ptr<TCPClient> client_;
};

View File

@@ -52,7 +52,7 @@ import mindspore._c_dataengine as cde
from mindspore._c_expression import typing
from mindspore import log as logger
from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched, _get_ps_context, _enable_distributed_mindrt
from mindspore.dataset.engine.offload import GetOffloadModel
import mindspore.dataset.transforms.c_transforms as c_transforms
@@ -1899,6 +1899,19 @@ class Dataset:
def parse(self, children=None):
raise NotImplementedError("Dataset has to implement parse method.")
@staticmethod
def _update_data_shard(num_shards, shard_id):
"""
Update the shard number and shard id if necessary.
This is normally used in distributed training mode like Parameter Server training.
"""
# If this is in distributed execution mode,
# the shard number and shard id might need to be updated according to the process's rank or role.
if _enable_distributed_mindrt() and _is_role_pserver():
num_shards = _get_ps_context("worker_num")
shard_id = 0
return num_shards, shard_id
def post_parse(self, ir_node):
if self.cache:
ir_node = ir_node.set_cache_client(self.cache.cache_client)
@@ -2163,6 +2176,7 @@ class MappableDataset(SourceDataset):
def __init__(self, num_parallel_workers=None, sampler=None, num_samples=None, shuffle=None, num_shards=None,
shard_id=None, cache=None):
num_shards, shard_id = self._update_data_shard(num_shards, shard_id)
super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
num_shards=num_shards, shard_id=shard_id, cache=cache)
self.shuffle_flag = replace_none(shuffle, True)
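With _update_data_shard wired into MappableDataset.__init__, a dataset constructed on a parameter server under distributed MindRT ignores the user-supplied num_shards/shard_id and instead spans the whole worker shard space. A self-contained sketch of that adjustment, with the PS context calls replaced by plain arguments:

def update_data_shard(num_shards, shard_id, is_pserver, distributed_mindrt, worker_num):
    """Standalone sketch of the shard adjustment done by Dataset._update_data_shard."""
    if distributed_mindrt and is_pserver:
        # The parameter server uses the worker count as the shard number and always reads shard 0.
        return worker_num, 0
    return num_shards, shard_id

# A parameter server in a 4-worker job overrides the user-provided sharding:
print(update_data_shard(num_shards=8, shard_id=3, is_pserver=True,
                        distributed_mindrt=True, worker_num=4))  # -> (4, 0)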

View File

@@ -302,6 +302,14 @@ def _is_fl_mode():
return _get_ps_context("server_mode") in ("FEDERATED_LEARNING", "HYBRID_TRAINING")
def _enable_distributed_mindrt():
'''
Whether the distributed MindRT is enabled.
This method is used to distinguish the distributed MindRT runtime from the old distributed training mode.
'''
return ps_context().enable_distributed_mindrt()
def _check_value(key, value):
"""
Validate the value for parameter server context keys.