diff --git a/mindspore/ccsrc/distributed/rpc/tcp/connection.cc b/mindspore/ccsrc/distributed/rpc/tcp/connection.cc
index 51658fd9e95..4636b38e3aa 100644
--- a/mindspore/ccsrc/distributed/rpc/tcp/connection.cc
+++ b/mindspore/ccsrc/distributed/rpc/tcp/connection.cc
@@ -25,6 +25,11 @@
 namespace mindspore {
 namespace distributed {
 namespace rpc {
+// Print the error message only once every 1000 times and sleep for 50ms in case the log file grows too large.
+static size_t kPrintCount = 0;
+size_t kPrintCountInterval = 1000;
+size_t kPrintTimeInterval = 50000;
+
 // Handle socket events like read/write.
 void SocketEventHandler(int fd, uint32_t events, void *context) {
   Connection *conn = reinterpret_cast<Connection *>(context);
@@ -61,9 +66,13 @@ void SocketEventHandler(int fd, uint32_t events, void *context) {
       (conn->recv_message_type != ParseType::kHttpReq && conn->recv_message_type != ParseType::kHttpRsp &&
        (events & (uint32_t)(EPOLLHUP | EPOLLRDHUP | EPOLLERR)))) {
     if (conn->recv_message_type == ParseType::kTcpMsg) {
-      MS_LOG(INFO) << "Event value fd: " << fd << ", events: " << events << ", state: " << conn->state
-                   << ", errcode: " << conn->error_code << ", errno: " << errno << ", to: " << conn->destination.c_str()
-                   << ", type:" << conn->recv_message_type << ", remote: " << conn->is_remote;
+      if (kPrintCount++ % kPrintCountInterval == 0) {
+        MS_LOG(INFO) << "Event value fd: " << fd << ", events: " << events << ", state: " << conn->state
+                     << ", errcode: " << conn->error_code << ", errno: " << errno
+                     << ", to: " << conn->destination.c_str() << ", type:" << conn->recv_message_type
+                     << ", remote: " << conn->is_remote;
+        usleep(kPrintTimeInterval);
+      }
     }
     conn->state = ConnectionState::kDisconnecting;
     if (conn->event_callback != nullptr) {
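The guard added above is a simple counter-based log throttle: only every 1000th disconnect event is logged, and each emitted message is followed by a 50ms pause so a tight error loop cannot flood the log file. Below is a minimal self-contained sketch of the same idea; the names (`ThrottledLog`, `kLogEvery`, `kBackoffUs`) are illustrative and not part of the MindSpore code above.

```cpp
// Counter-based log throttling, sketched under the assumption that a plain
// file-local counter is acceptable (connection.cc above uses the same approach).
#include <unistd.h>  // usleep, useconds_t

#include <cstddef>
#include <iostream>
#include <string>

namespace {
constexpr size_t kLogEvery = 1000;        // emit one message per 1000 calls
constexpr useconds_t kBackoffUs = 50000;  // pause 50ms after each emitted message
size_t g_call_count = 0;
}  // namespace

void ThrottledLog(const std::string &msg) {
  // Only every kLogEvery-th call actually writes to the log; the rest are dropped.
  if (g_call_count++ % kLogEvery == 0) {
    std::cout << msg << " (plus " << kLogEvery - 1 << " similar messages suppressed)" << std::endl;
    usleep(kBackoffUs);  // back off so a tight error loop cannot flood the log
  }
}

int main() {
  for (int i = 0; i < 3000; ++i) {
    ThrottledLog("socket event error");
  }
  return 0;
}
```

Note that `kPrintCount` in the hunk above is a plain `size_t`, so if several event-handler threads hit this path concurrently the sampling is only approximate; an `std::atomic<size_t>` would make it exact.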
diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/graph_splitter.cc b/mindspore/ccsrc/frontend/parallel/graph_util/graph_splitter.cc
index 0a08c2131cc..61eaa9c9891 100644
--- a/mindspore/ccsrc/frontend/parallel/graph_util/graph_splitter.cc
+++ b/mindspore/ccsrc/frontend/parallel/graph_util/graph_splitter.cc
@@ -982,7 +982,7 @@ void GraphSplitter::SplitGraph(const std::vector &segments,
   InOutDegreeList in_out_degree_list = GenerateInOutDegreeList(segments, comm_edges);
   if (in_out_degree_list.empty()) {
     MS_LOG(WARNING) << "After splitting, this process has no graph on it. So optimize out the whole graph.";
-    auto return_value_node = CreateFakeValueNode(false);
+    auto return_value_node = CreateReplacedOutputNode(func_graph_, func_graph_->output());
     (void)func_graph_->manager()->Replace(func_graph_->output(), return_value_node);
     return;
   }
@@ -995,6 +995,13 @@
 }
 
 void GraphSplitter::SplitGraph(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs) {
+  if (fused_inter_process_op_pairs.empty()) {
+    MS_LOG(WARNING) << "After splitting, this process has no graph on it. So optimize out the whole graph.";
+    auto return_value_node = CreateReplacedOutputNode(func_graph_, func_graph_->output());
+    (void)func_graph_->manager()->Replace(func_graph_->output(), return_value_node);
+    return;
+  }
+
   // Step 1: Replace origin nodes with recv nodes.
   ReplaceOriginNodesWithRecv(fused_inter_process_op_pairs);
diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc
index 332c145eb86..66ed3e5b6f3 100644
--- a/mindspore/ccsrc/pipeline/jit/init.cc
+++ b/mindspore/ccsrc/pipeline/jit/init.cc
@@ -557,7 +557,8 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("set_participation_time_level", &PSContext::set_participation_time_level, "Set participation time level.")
     .def("participation_time_level", &PSContext::participation_time_level, "Get participation time level.")
     .def("set_continuous_failure_times", &PSContext::set_continuous_failure_times, "Set continuous failure times")
-    .def("continuous_failure_times", &PSContext::continuous_failure_times, "Get continuous failure times.");
+    .def("continuous_failure_times", &PSContext::continuous_failure_times, "Get continuous failure times.")
+    .def("enable_distributed_mindrt", &PSContext::enable_distributed_mindrt, "Whether distributed MindRT is enabled.");
   (void)m.def("_encrypt", &mindspore::pipeline::PyEncrypt, "Encrypt the data.");
   (void)m.def("_decrypt", &mindspore::pipeline::PyDecrypt, "Decrypt the data.");
   (void)m.def("_is_cipher_file", &mindspore::pipeline::PyIsCipherFile, "Determine whether the file is encrypted");
diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc
index 61c709f73e5..94b4c74ea08 100644
--- a/mindspore/ccsrc/pipeline/jit/pipeline.cc
+++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc
@@ -735,18 +735,6 @@ std::vector GetPipeline(const ResourcePtr &resource, const std::stri
 
   std::string backend = MsContext::GetInstance()->backend_policy();
 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
-  const std::string &server_mode = ps::PSContext::instance()->server_mode();
-  if ((server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid) &&
-      ps::PSContext::instance()->is_server()) {
-    return ServerPipeline(resource);
-  }
-  if (ps::PSContext::instance()->is_server()) {
-    resource->SetResult(kBackend, compile::CreateBackend());
-    return PServerPipeline(resource);
-  }
-  if (ps::PSContext::instance()->is_scheduler()) {
-    return PSchedulerPipeline(resource);
-  }
   if (distributed::cluster::ClusterContext::instance()->initialized()) {
     auto node = distributed::cluster::ClusterContext::instance()->node();
     MS_EXCEPTION_IF_NULL(node);
@@ -758,6 +746,19 @@ std::vector GetPipeline(const ResourcePtr &resource, const std::stri
     } else if (cluster_ctx->node_role() == distributed::kEnvRoleOfScheduler) {
       return PSchedulerPipeline(resource);
     }
+  } else {
+    const std::string &server_mode = ps::PSContext::instance()->server_mode();
+    if ((server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid) &&
+        ps::PSContext::instance()->is_server()) {
+      return ServerPipeline(resource);
+    }
+    if (ps::PSContext::instance()->is_server()) {
+      resource->SetResult(kBackend, compile::CreateBackend());
+      return PServerPipeline(resource);
+    }
+    if (ps::PSContext::instance()->is_scheduler()) {
+      return PSchedulerPipeline(resource);
+    }
   }
 #endif
diff --git a/mindspore/ccsrc/ps/ps_context.cc b/mindspore/ccsrc/ps/ps_context.cc
index 7c06666d7ec..5954414610f 100644
--- a/mindspore/ccsrc/ps/ps_context.cc
+++ b/mindspore/ccsrc/ps/ps_context.cc
@@ -583,5 +583,9 @@ void PSContext::set_continuous_failure_times(uint32_t continuous_failure_times)
 }
 
 uint32_t PSContext::continuous_failure_times() { return continuous_failure_times_; }
+
+bool PSContext::enable_distributed_mindrt() const {
+  return distributed::cluster::ClusterContext::instance()->initialized();
+}
 }  // namespace ps
 }  // namespace mindspore
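The pipeline.cc hunk above inverts the old dispatch order: once `ClusterContext` reports that it is initialized, the MindRT cluster role decides which compile pipeline runs, and the legacy PS/FL role checks only apply in the new `else` branch. `PSContext::enable_distributed_mindrt()` simply exposes that same `initialized()` flag to Python. A hedged sketch of the resulting dispatch is shown below; `Pipeline`, `RuntimeState`, and `SelectPipeline` are illustrative stand-ins, not MindSpore APIs.

```cpp
// Sketch of the role-based pipeline dispatch established above: the cluster
// (MindRT) branch is consulted first, legacy PS checks only run otherwise.
#include <iostream>
#include <string>

enum class Pipeline { kWorker, kServer, kScheduler, kLegacyServer, kLegacyScheduler };

struct RuntimeState {
  bool cluster_initialized;  // corresponds to ClusterContext::instance()->initialized()
  std::string cluster_role;  // e.g. "MS_WORKER", "MS_PSERVER", "MS_SCHED" (assumed role strings)
  bool legacy_is_server;
  bool legacy_is_scheduler;
};

Pipeline SelectPipeline(const RuntimeState &s) {
  if (s.cluster_initialized) {
    // New distributed MindRT path: the role comes from the cluster context.
    if (s.cluster_role == "MS_PSERVER") return Pipeline::kServer;
    if (s.cluster_role == "MS_SCHED") return Pipeline::kScheduler;
    return Pipeline::kWorker;
  }
  // Legacy parameter-server path, only reachable when MindRT clustering is off.
  if (s.legacy_is_server) return Pipeline::kLegacyServer;
  if (s.legacy_is_scheduler) return Pipeline::kLegacyScheduler;
  return Pipeline::kWorker;
}

int main() {
  RuntimeState state{true, "MS_PSERVER", false, false};
  std::cout << static_cast<int>(SelectPipeline(state)) << std::endl;  // prints 1 (kServer)
  return 0;
}
```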
diff --git a/mindspore/ccsrc/ps/ps_context.h b/mindspore/ccsrc/ps/ps_context.h
index 9042461925b..6abab91bd39 100644
--- a/mindspore/ccsrc/ps/ps_context.h
+++ b/mindspore/ccsrc/ps/ps_context.h
@@ -253,6 +253,9 @@ class BACKEND_EXPORT PSContext {
   void set_continuous_failure_times(uint32_t continuous_failure_times);
   uint32_t continuous_failure_times();
 
+  // Whether distributed MindRT is enabled.
+  bool enable_distributed_mindrt() const;
+
  private:
   PSContext()
       : ps_enabled_(false),
diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/recv_actor.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/recv_actor.cc
index 64be19598fa..baaf95ca783 100644
--- a/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/recv_actor.cc
+++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/recv_actor.cc
@@ -25,6 +25,12 @@
 
 namespace mindspore {
 namespace runtime {
+RecvActor::~RecvActor() {
+  if (server_) {
+    server_->Finalize();
+  }
+}
+
 void RecvActor::SetOpcontext(OpContext<DeviceTensor> *const op_context) {
   std::unique_lock<std::mutex> lock(context_mtx_);
   MS_EXCEPTION_IF_NULL(op_context);
diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/recv_actor.h b/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/recv_actor.h
index 99493ba16ca..1ae4461323e 100644
--- a/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/recv_actor.h
+++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/recv_actor.h
@@ -36,8 +36,11 @@ class RecvActor : public RpcActor {
             const std::set<size_t> &modifiable_ref_output_indexes)
       : RpcActor(name, kernel, device_context, memory_manager_aid, debug_aid, recorder_aid, strategy,
                  modifiable_ref_input_indexes, modifiable_ref_output_indexes, KernelTransformType::kRecvActor),
+        ip_(""),
+        port_(0),
+        server_(nullptr),
         is_context_valid_(false) {}
-  ~RecvActor() override = default;
+  ~RecvActor() override;
 
   // Besides set the op context, this method also notify the message handler to 'RunOpInterProcessData'.
   void SetOpcontext(OpContext<DeviceTensor> *const op_context) override;
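The new `RecvActor` destructor finalizes the TCP server that the actor owns, and the null check keeps destruction safe for actors that were never launched (the `SendActor` change further down follows the same pattern for its client). A minimal sketch of that ownership rule, using a hypothetical `FakeServer` rather than the MindSpore `TCPServer`:

```cpp
// RAII-style teardown of an actor-owned transport object. FakeServer and
// RecvActorSketch are illustrative stand-ins, not MindSpore classes.
#include <iostream>
#include <memory>

class FakeServer {
 public:
  void Finalize() { std::cout << "server finalized" << std::endl; }
};

class RecvActorSketch {
 public:
  void Initialize() { server_ = std::make_unique<FakeServer>(); }
  ~RecvActorSketch() {
    // Guard against actors that were never launched: server_ stays null until Initialize().
    if (server_) {
      server_->Finalize();
    }
  }

 private:
  std::unique_ptr<FakeServer> server_;
};

int main() {
  RecvActorSketch actor;
  actor.Initialize();
  return 0;  // the destructor runs here and finalizes the server
}
```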
diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/send_actor.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/send_actor.cc
index be8a5de3457..8a976cf506c 100644
--- a/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/send_actor.cc
+++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/send_actor.cc
@@ -20,6 +20,13 @@
 
 namespace mindspore {
 namespace runtime {
+SendActor::~SendActor() {
+  if (client_) {
+    client_->Disconnect(server_url_);
+    client_->Finalize();
+  }
+}
+
 void SendActor::SetRouteInfo(uint32_t, const std::string &, const std::string &send_src_node_name,
                              const std::string &send_dst_node_name) {
   auto peer_actor_id = inter_process_edge_name_;
@@ -38,12 +45,12 @@ bool SendActor::ConnectServer() {
     MS_EXCEPTION_IF_NULL(actor_route_table_proxy_);
     auto peer_actor_address = actor_route_table_proxy_->LookupRoute(peer_actor_id);
     // If route is successfully looked up, peer_actor_address is not empty.
-    std::string server_url = peer_actor_address.ip() + ":" + std::to_string(peer_actor_address.port());
-    if (!client_->Connect(server_url)) {
-      MS_LOG(EXCEPTION) << "Failed to connect to server of actor " << peer_actor_id << ", server_url: " << server_url;
+    server_url_ = peer_actor_address.ip() + ":" + std::to_string(peer_actor_address.port());
+    if (!client_->Connect(server_url_)) {
+      MS_LOG(EXCEPTION) << "Failed to connect to server of actor " << peer_actor_id << ", server_url: " << server_url_;
     }
-    MS_LOG(INFO) << "Successfully connect to server " << server_url << ", inter-process edge name: " << peer_actor_id;
-    peer_actor_urls_[peer_actor_id] = server_url;
+    MS_LOG(INFO) << "Successfully connect to server " << server_url_ << ", inter-process edge name: " << peer_actor_id;
+    peer_actor_urls_[peer_actor_id] = server_url_;
   }
   return true;
 }
diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/send_actor.h b/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/send_actor.h
index e2e3952b38e..1c009ae2c80 100644
--- a/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/send_actor.h
+++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/send_actor.h
@@ -33,8 +33,10 @@ class SendActor : public RpcActor {
             GraphExecutionStrategy strategy, const std::set<size_t> &modifiable_ref_input_indexes,
             const std::set<size_t> &modifiable_ref_output_indexes)
       : RpcActor(name, kernel, device_context, memory_manager_aid, debug_aid, recorder_aid, strategy,
-                 modifiable_ref_input_indexes, modifiable_ref_output_indexes, KernelTransformType::kSendActor) {}
-  ~SendActor() override = default;
+                 modifiable_ref_input_indexes, modifiable_ref_output_indexes, KernelTransformType::kSendActor),
+        server_url_(""),
+        client_(nullptr) {}
+  ~SendActor() override;
 
   // Set send actor's destination peer info, in another word, send actor's output.
   void SetRouteInfo(uint32_t dst_rank, const std::string &dst_role, const std::string &send_src_node_name,
@@ -57,6 +59,9 @@ class SendActor : public RpcActor {
 
   std::vector<std::string> peer_actor_ids_;
  mindspore::HashMap<std::string, std::string> peer_actor_urls_;
+  // The URL of the peer recv actor's TCP server.
+  std::string server_url_;
+  std::unique_ptr<TCPClient> client_;
 };
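Promoting `server_url` from a local variable to the `server_url_` member is what lets the new `SendActor` destructor disconnect from the same endpoint it connected to before finalizing the client. A hedged sketch of that connect/disconnect lifecycle; `FakeClient` and `SendActorSketch` are illustrative, not the MindSpore `TCPClient` or `SendActor`:

```cpp
// Cache the peer URL at connect time so it is still available at destruction time.
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>

class FakeClient {
 public:
  bool Connect(const std::string &url) {
    std::cout << "connect " << url << std::endl;
    return true;
  }
  void Disconnect(const std::string &url) { std::cout << "disconnect " << url << std::endl; }
  void Finalize() { std::cout << "client finalized" << std::endl; }
};

class SendActorSketch {
 public:
  bool ConnectServer(const std::string &ip, uint16_t port) {
    client_ = std::make_unique<FakeClient>();
    // Remember the URL so the destructor can disconnect from the same endpoint later.
    server_url_ = ip + ":" + std::to_string(port);
    return client_->Connect(server_url_);
  }
  ~SendActorSketch() {
    if (client_) {
      client_->Disconnect(server_url_);
      client_->Finalize();
    }
  }

 private:
  std::string server_url_;
  std::unique_ptr<FakeClient> client_;
};

int main() {
  SendActorSketch actor;
  actor.ConnectServer("127.0.0.1", 8080);
  return 0;  // destructor disconnects, then finalizes
}
```

Initializing `client_` to `nullptr` in the constructor is what makes the destructor's `if (client_)` guard well defined for actors whose `ConnectServer()` was never called.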
diff --git a/mindspore/python/mindspore/dataset/engine/datasets.py b/mindspore/python/mindspore/dataset/engine/datasets.py
index 944dd402004..8ee4c13f47b 100644
--- a/mindspore/python/mindspore/dataset/engine/datasets.py
+++ b/mindspore/python/mindspore/dataset/engine/datasets.py
@@ -52,7 +52,7 @@
 import mindspore._c_dataengine as cde
 from mindspore._c_expression import typing
 from mindspore import log as logger
-from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
+from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched, _get_ps_context, _enable_distributed_mindrt
 from mindspore.dataset.engine.offload import GetOffloadModel
 
 import mindspore.dataset.transforms.c_transforms as c_transforms
@@ -1899,6 +1899,19 @@ class Dataset:
     def parse(self, children=None):
         raise NotImplementedError("Dataset has to implement parse method.")
 
+    @staticmethod
+    def _update_data_shard(num_shards, shard_id):
+        """
+        Update the shard number and shard id if necessary.
+        This is normally used in distributed training modes such as Parameter Server training.
+        """
+        # In distributed execution mode, the shard number and shard id may need to be
+        # updated according to the process's rank or role.
+        if _enable_distributed_mindrt() and _is_role_pserver():
+            num_shards = _get_ps_context("worker_num")
+            shard_id = 0
+        return num_shards, shard_id
+
     def post_parse(self, ir_node):
         if self.cache:
             ir_node = ir_node.set_cache_client(self.cache.cache_client)
@@ -2163,6 +2176,7 @@ class MappableDataset(SourceDataset):
 
     def __init__(self, num_parallel_workers=None, sampler=None, num_samples=None, shuffle=None, num_shards=None,
                  shard_id=None, cache=None):
+        num_shards, shard_id = self._update_data_shard(num_shards, shard_id)
         super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
                          num_shards=num_shards, shard_id=shard_id, cache=cache)
         self.shuffle_flag = replace_none(shuffle, True)
diff --git a/mindspore/python/mindspore/parallel/_ps_context.py b/mindspore/python/mindspore/parallel/_ps_context.py
index 117b62e6179..aa762da822a 100644
--- a/mindspore/python/mindspore/parallel/_ps_context.py
+++ b/mindspore/python/mindspore/parallel/_ps_context.py
@@ -302,6 +302,14 @@ def _is_fl_mode():
     return _get_ps_context("server_mode") in ("FEDERATED_LEARNING", "HYBRID_TRAINING")
 
 
+def _enable_distributed_mindrt():
+    '''
+    Whether distributed MindRT is enabled.
+    This method is used to distinguish the new MindRT runtime from the old distributed training mode.
+    '''
+    return ps_context().enable_distributed_mindrt()
+
+
 def _check_value(key, value):
     """
     Validate the value for parameter server context keys.