forked from mindspore-Ecosystem/mindspore

Adapt for ps test case

parent c29d6bb764
commit 78bd48fc57
@@ -25,6 +25,11 @@
 namespace mindspore {
 namespace distributed {
 namespace rpc {
+// Print error message every 1000 times and sleep for 5ms in case the log file is too large.
+static size_t kPrintCount = 0;
+size_t kPrintCountInterval = 1000;
+size_t kPrintTimeInterval = 50000;
+
 // Handle socket events like read/write.
 void SocketEventHandler(int fd, uint32_t events, void *context) {
   Connection *conn = reinterpret_cast<Connection *>(context);
@@ -61,9 +66,13 @@ void SocketEventHandler(int fd, uint32_t events, void *context) {
       (conn->recv_message_type != ParseType::kHttpReq && conn->recv_message_type != ParseType::kHttpRsp &&
        (events & (uint32_t)(EPOLLHUP | EPOLLRDHUP | EPOLLERR)))) {
     if (conn->recv_message_type == ParseType::kTcpMsg) {
-      MS_LOG(INFO) << "Event value fd: " << fd << ", events: " << events << ", state: " << conn->state
-                   << ", errcode: " << conn->error_code << ", errno: " << errno << ", to: " << conn->destination.c_str()
-                   << ", type:" << conn->recv_message_type << ", remote: " << conn->is_remote;
+      if (kPrintCount++ % kPrintCountInterval == 0) {
+        MS_LOG(INFO) << "Event value fd: " << fd << ", events: " << events << ", state: " << conn->state
+                     << ", errcode: " << conn->error_code << ", errno: " << errno
+                     << ", to: " << conn->destination.c_str() << ", type:" << conn->recv_message_type
+                     << ", remote: " << conn->is_remote;
+        usleep(kPrintTimeInterval);
+      }
     }
     conn->state = ConnectionState::kDisconnecting;
     if (conn->event_callback != nullptr) {
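The hunk above rate-limits the disconnect log: a counter gates how often the message is emitted, and a short sleep backs off after each emission so a burst of identical events cannot flood the log file. Below is a minimal standalone sketch of the same throttling idiom; the names, intervals, and the main() driver are illustrative only, not the MindSpore code.

#include <unistd.h>  // usleep, useconds_t

#include <cstddef>
#include <cstdint>
#include <iostream>

// Emit at most one message per kLogEvery occurrences, then pause briefly so a
// flood of identical events cannot fill the log or spin the CPU.
static size_t g_event_count = 0;
constexpr size_t kLogEvery = 1000;
constexpr useconds_t kBackoffUs = 50000;  // 50 ms

void ReportSocketEvent(int fd, uint32_t events) {
  if (g_event_count++ % kLogEvery == 0) {
    std::cout << "socket event, fd: " << fd << ", events: " << events << std::endl;
    usleep(kBackoffUs);
  }
}

int main() {
  // Only occurrences 0, 1000, 2000, 3000 and 4000 are actually printed.
  for (int i = 0; i < 5000; ++i) {
    ReportSocketEvent(3, 0x10);
  }
  return 0;
}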
@@ -982,7 +982,7 @@ void GraphSplitter::SplitGraph(const std::vector<SplitGraphSegment> &segments,
   InOutDegreeList in_out_degree_list = GenerateInOutDegreeList(segments, comm_edges);
   if (in_out_degree_list.empty()) {
     MS_LOG(WARNING) << "After splitting, this process has no graph on it. So optimize out the whole graph.";
-    auto return_value_node = CreateFakeValueNode(false);
+    auto return_value_node = CreateReplacedOutputNode(func_graph_, func_graph_->output());
     (void)func_graph_->manager()->Replace(func_graph_->output(), return_value_node);
     return;
   }
@@ -995,6 +995,13 @@ void GraphSplitter::SplitGraph(const std::vector<SplitGraphSegment> &segments,
 }

 void GraphSplitter::SplitGraph(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs) {
+  if (fused_inter_process_op_pairs.empty()) {
+    MS_LOG(WARNING) << "After splitting, this process has no graph on it. So optimize out the whole graph.";
+    auto return_value_node = CreateReplacedOutputNode(func_graph_, func_graph_->output());
+    (void)func_graph_->manager()->Replace(func_graph_->output(), return_value_node);
+    return;
+  }
+
   // Step 1: Replace origin nodes with recv nodes.
   ReplaceOriginNodesWithRecv(fused_inter_process_op_pairs);

@@ -557,7 +557,8 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("set_participation_time_level", &PSContext::set_participation_time_level, "Set participation time level.")
     .def("participation_time_level", &PSContext::participation_time_level, "Get participation time level.")
     .def("set_continuous_failure_times", &PSContext::set_continuous_failure_times, "Set continuous failure times")
-    .def("continuous_failure_times", &PSContext::continuous_failure_times, "Get continuous failure times.");
+    .def("continuous_failure_times", &PSContext::continuous_failure_times, "Get continuous failure times.")
+    .def("enable_distributed_mindrt", &PSContext::enable_distributed_mindrt, "Whether distributed MindRT is enabled.");
   (void)m.def("_encrypt", &mindspore::pipeline::PyEncrypt, "Encrypt the data.");
   (void)m.def("_decrypt", &mindspore::pipeline::PyDecrypt, "Decrypt the data.");
   (void)m.def("_is_cipher_file", &mindspore::pipeline::PyIsCipherFile, "Determine whether the file is encrypted");
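Because pybind11 .def() calls are chained on the class_ object, appending the new enable_distributed_mindrt binding requires moving the terminating semicolon from the continuous_failure_times line to the new last line, which is exactly what the hunk above does. A minimal sketch of the same chaining pattern, using a stand-in context class rather than the real PSContext:

#include <pybind11/pybind11.h>

#include <cstdint>

namespace py = pybind11;

// Stand-in for PSContext, only to show the chained-.def pattern.
struct ToyContext {
  uint32_t continuous_failure_times() const { return 0; }
  bool enable_distributed_mindrt() const { return true; }
};

PYBIND11_MODULE(toy_module, m) {
  // Every .def returns the class_ object, so the statement only ends at the
  // semicolon; adding a binding means the previous line drops its ';'.
  (void)py::class_<ToyContext>(m, "ToyContext")
    .def(py::init<>())
    .def("continuous_failure_times", &ToyContext::continuous_failure_times, "Get continuous failure times.")
    .def("enable_distributed_mindrt", &ToyContext::enable_distributed_mindrt,
         "Whether distributed MindRT is enabled.");
}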
@@ -735,18 +735,6 @@ std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::stri
   std::string backend = MsContext::GetInstance()->backend_policy();

 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
-  const std::string &server_mode = ps::PSContext::instance()->server_mode();
-  if ((server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid) &&
-      ps::PSContext::instance()->is_server()) {
-    return ServerPipeline(resource);
-  }
-  if (ps::PSContext::instance()->is_server()) {
-    resource->SetResult(kBackend, compile::CreateBackend());
-    return PServerPipeline(resource);
-  }
-  if (ps::PSContext::instance()->is_scheduler()) {
-    return PSchedulerPipeline(resource);
-  }
   if (distributed::cluster::ClusterContext::instance()->initialized()) {
     auto node = distributed::cluster::ClusterContext::instance()->node();
     MS_EXCEPTION_IF_NULL(node);
@@ -758,6 +746,19 @@ std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::stri
     } else if (cluster_ctx->node_role() == distributed::kEnvRoleOfScheduler) {
       return PSchedulerPipeline(resource);
     }
+  } else {
+    const std::string &server_mode = ps::PSContext::instance()->server_mode();
+    if ((server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid) &&
+        ps::PSContext::instance()->is_server()) {
+      return ServerPipeline(resource);
+    }
+    if (ps::PSContext::instance()->is_server()) {
+      resource->SetResult(kBackend, compile::CreateBackend());
+      return PServerPipeline(resource);
+    }
+    if (ps::PSContext::instance()->is_scheduler()) {
+      return PSchedulerPipeline(resource);
+    }
   }
 #endif

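Taken together, the two GetPipeline hunks move the legacy PSContext role checks into an else branch, so they only run when the distributed MindRT cluster context has not been initialized. A rough sketch of the resulting dispatch order follows; the RoleFlags struct, the ChoosePipeline name, and the literal mode strings are illustrative stand-ins for the real ClusterContext/PSContext singletons.

#include <string>

enum class Pipeline { kServer, kPServer, kPScheduler, kDefault };

// Hypothetical snapshot of the flags GetPipeline consults.
struct RoleFlags {
  bool cluster_initialized = false;  // distributed MindRT cluster is up
  bool is_scheduler_role = false;    // node role reported by the cluster context
  std::string server_mode;           // e.g. "FEDERATED_LEARNING" or "HYBRID_TRAINING"
  bool is_server = false;
  bool is_scheduler = false;
};

Pipeline ChoosePipeline(const RoleFlags &f) {
  if (f.cluster_initialized) {
    // Distributed MindRT path: the pipeline is chosen from the cluster node role.
    return f.is_scheduler_role ? Pipeline::kPScheduler : Pipeline::kDefault;
  }
  // Legacy parameter-server path, now reached only when the cluster is not initialized.
  if ((f.server_mode == "FEDERATED_LEARNING" || f.server_mode == "HYBRID_TRAINING") && f.is_server) {
    return Pipeline::kServer;
  }
  if (f.is_server) {
    return Pipeline::kPServer;
  }
  if (f.is_scheduler) {
    return Pipeline::kPScheduler;
  }
  return Pipeline::kDefault;
}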
@@ -583,5 +583,9 @@ void PSContext::set_continuous_failure_times(uint32_t continuous_failure_times)
 }

 uint32_t PSContext::continuous_failure_times() { return continuous_failure_times_; }
+
+bool PSContext::enable_distributed_mindrt() const {
+  return distributed::cluster::ClusterContext::instance()->initialized();
+}
 }  // namespace ps
 }  // namespace mindspore
@@ -253,6 +253,9 @@ class BACKEND_EXPORT PSContext {
   void set_continuous_failure_times(uint32_t continuous_failure_times);
   uint32_t continuous_failure_times();

+  // Whether distributed MindRT is enabled.
+  bool enable_distributed_mindrt() const;
+
  private:
   PSContext()
       : ps_enabled_(false),
@@ -25,6 +25,12 @@

 namespace mindspore {
 namespace runtime {
+RecvActor::~RecvActor() {
+  if (server_) {
+    server_->Finalize();
+  }
+}
+
 void RecvActor::SetOpcontext(OpContext<DeviceTensor> *const op_context) {
   std::unique_lock<std::mutex> lock(context_mtx_);
   MS_EXCEPTION_IF_NULL(op_context);
@@ -36,8 +36,11 @@ class RecvActor : public RpcActor {
             const std::set<size_t> &modifiable_ref_output_indexes)
       : RpcActor(name, kernel, device_context, memory_manager_aid, debug_aid, recorder_aid, strategy,
                  modifiable_ref_input_indexes, modifiable_ref_output_indexes, KernelTransformType::kRecvActor),
+        ip_(""),
+        port_(0),
+        server_(nullptr),
         is_context_valid_(false) {}
-  ~RecvActor() override = default;
+  ~RecvActor() override;

   // Besides set the op context, this method also notify the message handler to 'RunOpInterProcessData'.
   void SetOpcontext(OpContext<DeviceTensor> *const op_context) override;
@@ -20,6 +20,13 @@

 namespace mindspore {
 namespace runtime {
+SendActor::~SendActor() {
+  if (client_) {
+    client_->Disconnect(server_url_);
+    client_->Finalize();
+  }
+}
+
 void SendActor::SetRouteInfo(uint32_t, const std::string &, const std::string &send_src_node_name,
                              const std::string &send_dst_node_name) {
   auto peer_actor_id = inter_process_edge_name_;
@@ -38,12 +45,12 @@ bool SendActor::ConnectServer() {
     MS_EXCEPTION_IF_NULL(actor_route_table_proxy_);
     auto peer_actor_address = actor_route_table_proxy_->LookupRoute(peer_actor_id);
     // If route is successfully looked up, peer_actor_address is not empty.
-    std::string server_url = peer_actor_address.ip() + ":" + std::to_string(peer_actor_address.port());
-    if (!client_->Connect(server_url)) {
-      MS_LOG(EXCEPTION) << "Failed to connect to server of actor " << peer_actor_id << ", server_url: " << server_url;
+    server_url_ = peer_actor_address.ip() + ":" + std::to_string(peer_actor_address.port());
+    if (!client_->Connect(server_url_)) {
+      MS_LOG(EXCEPTION) << "Failed to connect to server of actor " << peer_actor_id << ", server_url: " << server_url_;
     }
-    MS_LOG(INFO) << "Successfully connect to server " << server_url << ", inter-process edge name: " << peer_actor_id;
-    peer_actor_urls_[peer_actor_id] = server_url;
+    MS_LOG(INFO) << "Successfully connect to server " << server_url_ << ", inter-process edge name: " << peer_actor_id;
+    peer_actor_urls_[peer_actor_id] = server_url_;
   }
   return true;
 }
@@ -33,8 +33,10 @@ class SendActor : public RpcActor {
             GraphExecutionStrategy strategy, const std::set<size_t> &modifiable_ref_input_indexes,
             const std::set<size_t> &modifiable_ref_output_indexes)
       : RpcActor(name, kernel, device_context, memory_manager_aid, debug_aid, recorder_aid, strategy,
-                 modifiable_ref_input_indexes, modifiable_ref_output_indexes, KernelTransformType::kSendActor) {}
-  ~SendActor() override = default;
+                 modifiable_ref_input_indexes, modifiable_ref_output_indexes, KernelTransformType::kSendActor),
+        server_url_(""),
+        client_(nullptr) {}
+  ~SendActor() override;

   // Set send actor's destination peer info, in another word, send actor's output.
   void SetRouteInfo(uint32_t dst_rank, const std::string &dst_role, const std::string &send_src_node_name,
@@ -57,6 +59,9 @@ class SendActor : public RpcActor {
   std::vector<std::string> peer_actor_ids_;
   mindspore::HashMap<std::string, std::string> peer_actor_urls_;

+  // The url of the peer recv actor's tcp server.
+  std::string server_url_;
+
   std::unique_ptr<TCPClient> client_;
 };

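The SendActor hunks follow one pattern: ConnectServer() now caches the peer URL in the new server_url_ member, and the destructor, which can no longer be '= default', uses it to disconnect and finalize the TCP client. A toy sketch of that ownership pattern, with stand-in types instead of the real TCPClient/RpcActor classes:

#include <iostream>
#include <memory>
#include <string>

// Stand-in for TCPClient.
class ToyClient {
 public:
  bool Connect(const std::string &url) { std::cout << "connect " << url << std::endl; return true; }
  void Disconnect(const std::string &url) { std::cout << "disconnect " << url << std::endl; }
  void Finalize() { std::cout << "finalize client" << std::endl; }
};

class ToySendActor {
 public:
  ToySendActor() : server_url_(""), client_(nullptr) {}
  // The destructor must tear down whatever ConnectServer() established.
  ~ToySendActor() {
    if (client_) {
      client_->Disconnect(server_url_);
      client_->Finalize();
    }
  }

  bool ConnectServer(const std::string &peer_ip, int peer_port) {
    client_ = std::make_unique<ToyClient>();
    // Cache the URL in a member so the destructor knows what to disconnect from.
    server_url_ = peer_ip + ":" + std::to_string(peer_port);
    return client_->Connect(server_url_);
  }

 private:
  std::string server_url_;
  std::unique_ptr<ToyClient> client_;
};

int main() {
  ToySendActor actor;
  actor.ConnectServer("127.0.0.1", 8080);
  return 0;  // destructor disconnects and finalizes
}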
@@ -52,7 +52,7 @@ import mindspore._c_dataengine as cde
 from mindspore._c_expression import typing

 from mindspore import log as logger
-from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
+from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched, _get_ps_context, _enable_distributed_mindrt
 from mindspore.dataset.engine.offload import GetOffloadModel

 import mindspore.dataset.transforms.c_transforms as c_transforms
@@ -1899,6 +1899,19 @@ class Dataset:
     def parse(self, children=None):
         raise NotImplementedError("Dataset has to implement parse method.")

+    @staticmethod
+    def _update_data_shard(num_shards, shard_id):
+        """
+        Update the shard number and shard id if necessary.
+        This is normally used in distributed training mode like Parameter Server training.
+        """
+        # If this is in distributed execution mode,
+        # the shard number and shard id might need to be updated according to the process's rank or role.
+        if _enable_distributed_mindrt() and _is_role_pserver():
+            num_shards = _get_ps_context("worker_num")
+            shard_id = 0
+        return num_shards, shard_id
+
     def post_parse(self, ir_node):
         if self.cache:
             ir_node = ir_node.set_cache_client(self.cache.cache_client)
@@ -2163,6 +2176,7 @@ class MappableDataset(SourceDataset):

     def __init__(self, num_parallel_workers=None, sampler=None, num_samples=None, shuffle=None, num_shards=None,
                  shard_id=None, cache=None):
+        num_shards, shard_id = self._update_data_shard(num_shards, shard_id)
         super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
                          num_shards=num_shards, shard_id=shard_id, cache=cache)
         self.shuffle_flag = replace_none(shuffle, True)
@@ -302,6 +302,14 @@ def _is_fl_mode():
     return _get_ps_context("server_mode") in ("FEDERATED_LEARNING", "HYBRID_TRAINING")


+def _enable_distributed_mindrt():
+    '''
+    Whether the distributed MindRT is enabled.
+    This method is used to distinguish from old distributed training mode.
+    '''
+    return ps_context().enable_distributed_mindrt()
+
+
 def _check_value(key, value):
     """
     Validate the value for parameter server context keys.