Adapt for ps test case

ZPaC 2022-04-29 18:43:53 +08:00
parent c29d6bb764
commit 78bd48fc57
12 changed files with 94 additions and 26 deletions

View File

@@ -25,6 +25,11 @@
namespace mindspore {
namespace distributed {
namespace rpc {
// Print the error message only once every 1000 occurrences and sleep for 50ms each time, so the log file does not grow too large.
static size_t kPrintCount = 0;
size_t kPrintCountInterval = 1000;
size_t kPrintTimeInterval = 50000;
// Handle socket events like read/write.
void SocketEventHandler(int fd, uint32_t events, void *context) {
Connection *conn = reinterpret_cast<Connection *>(context);
@@ -61,9 +66,13 @@ void SocketEventHandler(int fd, uint32_t events, void *context) {
(conn->recv_message_type != ParseType::kHttpReq && conn->recv_message_type != ParseType::kHttpRsp &&
(events & (uint32_t)(EPOLLHUP | EPOLLRDHUP | EPOLLERR)))) {
if (conn->recv_message_type == ParseType::kTcpMsg) {
if (kPrintCount++ % kPrintCountInterval == 0) {
MS_LOG(INFO) << "Event value fd: " << fd << ", events: " << events << ", state: " << conn->state
<< ", errcode: " << conn->error_code << ", errno: " << errno << ", to: " << conn->destination.c_str()
<< ", type:" << conn->recv_message_type << ", remote: " << conn->is_remote;
<< ", errcode: " << conn->error_code << ", errno: " << errno
<< ", to: " << conn->destination.c_str() << ", type:" << conn->recv_message_type
<< ", remote: " << conn->is_remote;
usleep(kPrintTimeInterval);
}
}
conn->state = ConnectionState::kDisconnecting;
if (conn->event_callback != nullptr) {
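The hunk above adds a general throttling pattern: log only one in every kPrintCountInterval occurrences of a noisy connection error, then back off for kPrintTimeInterval microseconds. A minimal standalone sketch of the same idea (Python for brevity; the function name and message are illustrative, not MindSpore code):

import time

_print_count = 0
_PRINT_COUNT_INTERVAL = 1000   # report only every 1000th occurrence
_PRINT_TIME_INTERVAL_S = 0.05  # 50 ms back-off, mirroring usleep(50000)

def report_connection_error(fd, err):
    """Log a connection error at a throttled rate so the log stays small."""
    global _print_count
    if _print_count % _PRINT_COUNT_INTERVAL == 0:
        print(f"connection error, fd: {fd}, errno: {err}")
        time.sleep(_PRINT_TIME_INTERVAL_S)
    _print_count += 1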

View File

@@ -982,7 +982,7 @@ void GraphSplitter::SplitGraph(const std::vector<SplitGraphSegment> &segments,
InOutDegreeList in_out_degree_list = GenerateInOutDegreeList(segments, comm_edges);
if (in_out_degree_list.empty()) {
MS_LOG(WARNING) << "After splitting, this process has no graph on it. So optimize out the whole graph.";
auto return_value_node = CreateFakeValueNode(false);
auto return_value_node = CreateReplacedOutputNode(func_graph_, func_graph_->output());
(void)func_graph_->manager()->Replace(func_graph_->output(), return_value_node);
return;
}
@@ -995,6 +995,13 @@ void GraphSplitter::SplitGraph(const std::vector<SplitGraphSegment> &segments,
}
void GraphSplitter::SplitGraph(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs) {
if (fused_inter_process_op_pairs.empty()) {
MS_LOG(WARNING) << "After splitting, this process has no graph on it. So optimize out the whole graph.";
auto return_value_node = CreateReplacedOutputNode(func_graph_, func_graph_->output());
(void)func_graph_->manager()->Replace(func_graph_->output(), return_value_node);
return;
}
// Step 1: Replace origin nodes with recv nodes.
ReplaceOriginNodesWithRecv(fused_inter_process_op_pairs);

View File

@@ -557,7 +557,8 @@ PYBIND11_MODULE(_c_expression, m) {
.def("set_participation_time_level", &PSContext::set_participation_time_level, "Set participation time level.")
.def("participation_time_level", &PSContext::participation_time_level, "Get participation time level.")
.def("set_continuous_failure_times", &PSContext::set_continuous_failure_times, "Set continuous failure times")
.def("continuous_failure_times", &PSContext::continuous_failure_times, "Get continuous failure times.");
.def("continuous_failure_times", &PSContext::continuous_failure_times, "Get continuous failure times.")
.def("enable_distributed_mindrt", &PSContext::enable_distributed_mindrt, "Whether distributed MindRT is enabled.");
(void)m.def("_encrypt", &mindspore::pipeline::PyEncrypt, "Encrypt the data.");
(void)m.def("_decrypt", &mindspore::pipeline::PyDecrypt, "Decrypt the data.");
(void)m.def("_is_cipher_file", &mindspore::pipeline::PyIsCipherFile, "Determine whether the file is encrypted");

View File

@@ -735,6 +735,18 @@ std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::stri
std::string backend = MsContext::GetInstance()->backend_policy();
#if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
if (distributed::cluster::ClusterContext::instance()->initialized()) {
auto node = distributed::cluster::ClusterContext::instance()->node();
MS_EXCEPTION_IF_NULL(node);
const auto &cluster_ctx = distributed::cluster::ClusterContext::instance();
MS_EXCEPTION_IF_NULL(cluster_ctx);
MS_LOG(INFO) << "Cluster is initialized. This node role is " << cluster_ctx->node_role();
if (cluster_ctx->node_role() == distributed::kEnvRoleOfServer) {
return PServerPipeline(resource);
} else if (cluster_ctx->node_role() == distributed::kEnvRoleOfScheduler) {
return PSchedulerPipeline(resource);
}
} else {
const std::string &server_mode = ps::PSContext::instance()->server_mode();
if ((server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid) &&
ps::PSContext::instance()->is_server()) {
@@ -747,17 +759,6 @@ std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::stri
if (ps::PSContext::instance()->is_scheduler()) {
return PSchedulerPipeline(resource);
}
if (distributed::cluster::ClusterContext::instance()->initialized()) {
auto node = distributed::cluster::ClusterContext::instance()->node();
MS_EXCEPTION_IF_NULL(node);
const auto &cluster_ctx = distributed::cluster::ClusterContext::instance();
MS_EXCEPTION_IF_NULL(cluster_ctx);
MS_LOG(INFO) << "Cluster is initialized. This node role is " << cluster_ctx->node_role();
if (cluster_ctx->node_role() == distributed::kEnvRoleOfServer) {
return PServerPipeline(resource);
} else if (cluster_ctx->node_role() == distributed::kEnvRoleOfScheduler) {
return PSchedulerPipeline(resource);
}
}
#endif

View File

@@ -583,5 +583,9 @@ void PSContext::set_continuous_failure_times(uint32_t continuous_failure_times)
}
uint32_t PSContext::continuous_failure_times() { return continuous_failure_times_; }
bool PSContext::enable_distributed_mindrt() const {
return distributed::cluster::ClusterContext::instance()->initialized();
}
} // namespace ps
} // namespace mindspore

View File

@@ -253,6 +253,9 @@ class BACKEND_EXPORT PSContext {
void set_continuous_failure_times(uint32_t continuous_failure_times);
uint32_t continuous_failure_times();
// Whether distributed MindRT is enabled.
bool enable_distributed_mindrt() const;
private:
PSContext()
: ps_enabled_(false),

View File

@@ -25,6 +25,12 @@
namespace mindspore {
namespace runtime {
RecvActor::~RecvActor() {
if (server_) {
server_->Finalize();
}
}
void RecvActor::SetOpcontext(OpContext<DeviceTensor> *const op_context) {
std::unique_lock<std::mutex> lock(context_mtx_);
MS_EXCEPTION_IF_NULL(op_context);

View File

@@ -36,8 +36,11 @@ class RecvActor : public RpcActor {
const std::set<size_t> &modifiable_ref_output_indexes)
: RpcActor(name, kernel, device_context, memory_manager_aid, debug_aid, recorder_aid, strategy,
modifiable_ref_input_indexes, modifiable_ref_output_indexes, KernelTransformType::kRecvActor),
ip_(""),
port_(0),
server_(nullptr),
is_context_valid_(false) {}
~RecvActor() override = default;
~RecvActor() override;
// Besides setting the op context, this method also notifies the message handler to run 'RunOpInterProcessData'.
void SetOpcontext(OpContext<DeviceTensor> *const op_context) override;
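The comment above describes a set-and-notify handshake: the pipeline thread publishes the op context under context_mtx_, and the message handler blocks until a valid context exists before it runs 'RunOpInterProcessData'. A standalone sketch of that handshake (assumed to use a condition variable; names are illustrative, not the MindSpore implementation):

import threading

class OpContextHolder:
    """Publish a context from one thread and let another thread wait for it."""

    def __init__(self):
        self._cond = threading.Condition()
        self._op_context = None

    def set_op_context(self, ctx):
        with self._cond:
            self._op_context = ctx
            self._cond.notify_all()  # wake handlers waiting for a valid context

    def wait_for_op_context(self):
        with self._cond:
            self._cond.wait_for(lambda: self._op_context is not None)
            return self._op_context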

View File

@@ -20,6 +20,13 @@
namespace mindspore {
namespace runtime {
SendActor::~SendActor() {
if (client_) {
client_->Disconnect(server_url_);
client_->Finalize();
}
}
void SendActor::SetRouteInfo(uint32_t, const std::string &, const std::string &send_src_node_name,
const std::string &send_dst_node_name) {
auto peer_actor_id = inter_process_edge_name_;
@@ -38,12 +45,12 @@ bool SendActor::ConnectServer() {
MS_EXCEPTION_IF_NULL(actor_route_table_proxy_);
auto peer_actor_address = actor_route_table_proxy_->LookupRoute(peer_actor_id);
// If route is successfully looked up, peer_actor_address is not empty.
std::string server_url = peer_actor_address.ip() + ":" + std::to_string(peer_actor_address.port());
if (!client_->Connect(server_url)) {
MS_LOG(EXCEPTION) << "Failed to connect to server of actor " << peer_actor_id << ", server_url: " << server_url;
server_url_ = peer_actor_address.ip() + ":" + std::to_string(peer_actor_address.port());
if (!client_->Connect(server_url_)) {
MS_LOG(EXCEPTION) << "Failed to connect to server of actor " << peer_actor_id << ", server_url: " << server_url_;
}
MS_LOG(INFO) << "Successfully connect to server " << server_url << ", inter-process edge name: " << peer_actor_id;
peer_actor_urls_[peer_actor_id] = server_url;
MS_LOG(INFO) << "Successfully connect to server " << server_url_ << ", inter-process edge name: " << peer_actor_id;
peer_actor_urls_[peer_actor_id] = server_url_;
}
return true;
}

View File

@@ -33,8 +33,10 @@ class SendActor : public RpcActor {
GraphExecutionStrategy strategy, const std::set<size_t> &modifiable_ref_input_indexes,
const std::set<size_t> &modifiable_ref_output_indexes)
: RpcActor(name, kernel, device_context, memory_manager_aid, debug_aid, recorder_aid, strategy,
modifiable_ref_input_indexes, modifiable_ref_output_indexes, KernelTransformType::kSendActor) {}
~SendActor() override = default;
modifiable_ref_input_indexes, modifiable_ref_output_indexes, KernelTransformType::kSendActor),
server_url_(""),
client_(nullptr) {}
~SendActor() override;
// Set the send actor's destination peer info, in other words, the send actor's output.
void SetRouteInfo(uint32_t dst_rank, const std::string &dst_role, const std::string &send_src_node_name,
@@ -57,6 +59,9 @@ class SendActor : public RpcActor {
std::vector<std::string> peer_actor_ids_;
mindspore::HashMap<std::string, std::string> peer_actor_urls_;
// The URL of the peer recv actor's TCP server.
std::string server_url_;
std::unique_ptr<TCPClient> client_;
};

View File

@@ -52,7 +52,7 @@ import mindspore._c_dataengine as cde
from mindspore._c_expression import typing
from mindspore import log as logger
from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched, _get_ps_context, _enable_distributed_mindrt
from mindspore.dataset.engine.offload import GetOffloadModel
import mindspore.dataset.transforms.c_transforms as c_transforms
@@ -1899,6 +1899,19 @@ class Dataset:
def parse(self, children=None):
raise NotImplementedError("Dataset has to implement parse method.")
@staticmethod
def _update_data_shard(num_shards, shard_id):
"""
Update the shard number and shard id if necessary.
This is normally used in distributed training modes such as Parameter Server training.
"""
# If this is in distributed execution mode,
# the shard number and shard id might need to be updated according to the process's rank or role.
if _enable_distributed_mindrt() and _is_role_pserver():
num_shards = _get_ps_context("worker_num")
shard_id = 0
return num_shards, shard_id
def post_parse(self, ir_node):
if self.cache:
ir_node = ir_node.set_cache_client(self.cache.cache_client)
@@ -2163,6 +2176,7 @@ class MappableDataset(SourceDataset):
def __init__(self, num_parallel_workers=None, sampler=None, num_samples=None, shuffle=None, num_shards=None,
shard_id=None, cache=None):
num_shards, shard_id = self._update_data_shard(num_shards, shard_id)
super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
num_shards=num_shards, shard_id=shard_id, cache=cache)
self.shuffle_flag = replace_none(shuffle, True)
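The effect of the new _update_data_shard call in MappableDataset.__init__ can be seen with a hedged usage sketch (the dataset class, path, and shard arguments below are hypothetical; the override only applies on a parameter-server process with distributed MindRT enabled):

import mindspore.dataset as ds

# The user asks for shard 3 of 8.
dataset = ds.ImageFolderDataset("/path/to/data", num_shards=8, shard_id=3)
# On a parameter server with _enable_distributed_mindrt() returning True,
# _update_data_shard rewrites these values before the base class sees them:
# num_shards becomes _get_ps_context("worker_num") and shard_id becomes 0.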

View File

@@ -302,6 +302,14 @@ def _is_fl_mode():
return _get_ps_context("server_mode") in ("FEDERATED_LEARNING", "HYBRID_TRAINING")
def _enable_distributed_mindrt():
"""
Whether distributed MindRT is enabled.
This method is used to distinguish the MindRT-based mode from the old distributed training mode.
"""
return ps_context().enable_distributed_mindrt()
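A hedged usage sketch of the new helper (the branch bodies are placeholders, not MindSpore code): callers can combine it with the existing role checks to choose between the MindRT-based path and the legacy parameter-server path.

from mindspore.parallel._ps_context import _enable_distributed_mindrt, _is_role_pserver

if _enable_distributed_mindrt() and _is_role_pserver():
    pass  # handling specific to the new distributed MindRT runtime
else:
    pass  # legacy parameter-server or standalone handling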
def _check_value(key, value):
"""
Validate the value for parameter server context keys.