forked from mindspore-Ecosystem/mindspore
Adapt for ps test case
This commit is contained in:
parent
c29d6bb764
commit
78bd48fc57
|
@ -25,6 +25,11 @@
|
|||
namespace mindspore {
|
||||
namespace distributed {
|
||||
namespace rpc {
|
||||
// Print error message every 1000 times and sleep for 5ms in case the log file is too large.
|
||||
static size_t kPrintCount = 0;
|
||||
size_t kPrintCountInterval = 1000;
|
||||
size_t kPrintTimeInterval = 50000;
|
||||
|
||||
// Handle socket events like read/write.
|
||||
void SocketEventHandler(int fd, uint32_t events, void *context) {
|
||||
Connection *conn = reinterpret_cast<Connection *>(context);
|
||||
|
@ -61,9 +66,13 @@ void SocketEventHandler(int fd, uint32_t events, void *context) {
|
|||
(conn->recv_message_type != ParseType::kHttpReq && conn->recv_message_type != ParseType::kHttpRsp &&
|
||||
(events & (uint32_t)(EPOLLHUP | EPOLLRDHUP | EPOLLERR)))) {
|
||||
if (conn->recv_message_type == ParseType::kTcpMsg) {
|
||||
MS_LOG(INFO) << "Event value fd: " << fd << ", events: " << events << ", state: " << conn->state
|
||||
<< ", errcode: " << conn->error_code << ", errno: " << errno << ", to: " << conn->destination.c_str()
|
||||
<< ", type:" << conn->recv_message_type << ", remote: " << conn->is_remote;
|
||||
if (kPrintCount++ % kPrintCountInterval == 0) {
|
||||
MS_LOG(INFO) << "Event value fd: " << fd << ", events: " << events << ", state: " << conn->state
|
||||
<< ", errcode: " << conn->error_code << ", errno: " << errno
|
||||
<< ", to: " << conn->destination.c_str() << ", type:" << conn->recv_message_type
|
||||
<< ", remote: " << conn->is_remote;
|
||||
usleep(kPrintTimeInterval);
|
||||
}
|
||||
}
|
||||
conn->state = ConnectionState::kDisconnecting;
|
||||
if (conn->event_callback != nullptr) {
|
||||
|
|
|
@ -982,7 +982,7 @@ void GraphSplitter::SplitGraph(const std::vector<SplitGraphSegment> &segments,
|
|||
InOutDegreeList in_out_degree_list = GenerateInOutDegreeList(segments, comm_edges);
|
||||
if (in_out_degree_list.empty()) {
|
||||
MS_LOG(WARNING) << "After splitting, this process has no graph on it. So optimize out the whole graph.";
|
||||
auto return_value_node = CreateFakeValueNode(false);
|
||||
auto return_value_node = CreateReplacedOutputNode(func_graph_, func_graph_->output());
|
||||
(void)func_graph_->manager()->Replace(func_graph_->output(), return_value_node);
|
||||
return;
|
||||
}
|
||||
|
@ -995,6 +995,13 @@ void GraphSplitter::SplitGraph(const std::vector<SplitGraphSegment> &segments,
|
|||
}
|
||||
|
||||
void GraphSplitter::SplitGraph(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs) {
|
||||
if (fused_inter_process_op_pairs.empty()) {
|
||||
MS_LOG(WARNING) << "After splitting, this process has no graph on it. So optimize out the whole graph.";
|
||||
auto return_value_node = CreateReplacedOutputNode(func_graph_, func_graph_->output());
|
||||
(void)func_graph_->manager()->Replace(func_graph_->output(), return_value_node);
|
||||
return;
|
||||
}
|
||||
|
||||
// Step 1: Replace origin nodes with recv nodes.
|
||||
ReplaceOriginNodesWithRecv(fused_inter_process_op_pairs);
|
||||
|
||||
|
|
|
@ -557,7 +557,8 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
.def("set_participation_time_level", &PSContext::set_participation_time_level, "Set participation time level.")
|
||||
.def("participation_time_level", &PSContext::participation_time_level, "Get participation time level.")
|
||||
.def("set_continuous_failure_times", &PSContext::set_continuous_failure_times, "Set continuous failure times")
|
||||
.def("continuous_failure_times", &PSContext::continuous_failure_times, "Get continuous failure times.");
|
||||
.def("continuous_failure_times", &PSContext::continuous_failure_times, "Get continuous failure times.")
|
||||
.def("enable_distributed_mindrt", &PSContext::enable_distributed_mindrt, "Whether distributed MindRT is enabled.");
|
||||
(void)m.def("_encrypt", &mindspore::pipeline::PyEncrypt, "Encrypt the data.");
|
||||
(void)m.def("_decrypt", &mindspore::pipeline::PyDecrypt, "Decrypt the data.");
|
||||
(void)m.def("_is_cipher_file", &mindspore::pipeline::PyIsCipherFile, "Determine whether the file is encrypted");
|
||||
|
|
|
@ -735,18 +735,6 @@ std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::stri
|
|||
std::string backend = MsContext::GetInstance()->backend_policy();
|
||||
|
||||
#if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
|
||||
const std::string &server_mode = ps::PSContext::instance()->server_mode();
|
||||
if ((server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid) &&
|
||||
ps::PSContext::instance()->is_server()) {
|
||||
return ServerPipeline(resource);
|
||||
}
|
||||
if (ps::PSContext::instance()->is_server()) {
|
||||
resource->SetResult(kBackend, compile::CreateBackend());
|
||||
return PServerPipeline(resource);
|
||||
}
|
||||
if (ps::PSContext::instance()->is_scheduler()) {
|
||||
return PSchedulerPipeline(resource);
|
||||
}
|
||||
if (distributed::cluster::ClusterContext::instance()->initialized()) {
|
||||
auto node = distributed::cluster::ClusterContext::instance()->node();
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
@ -758,6 +746,19 @@ std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::stri
|
|||
} else if (cluster_ctx->node_role() == distributed::kEnvRoleOfScheduler) {
|
||||
return PSchedulerPipeline(resource);
|
||||
}
|
||||
} else {
|
||||
const std::string &server_mode = ps::PSContext::instance()->server_mode();
|
||||
if ((server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid) &&
|
||||
ps::PSContext::instance()->is_server()) {
|
||||
return ServerPipeline(resource);
|
||||
}
|
||||
if (ps::PSContext::instance()->is_server()) {
|
||||
resource->SetResult(kBackend, compile::CreateBackend());
|
||||
return PServerPipeline(resource);
|
||||
}
|
||||
if (ps::PSContext::instance()->is_scheduler()) {
|
||||
return PSchedulerPipeline(resource);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
|
@ -583,5 +583,9 @@ void PSContext::set_continuous_failure_times(uint32_t continuous_failure_times)
|
|||
}
|
||||
|
||||
uint32_t PSContext::continuous_failure_times() { return continuous_failure_times_; }
|
||||
|
||||
bool PSContext::enable_distributed_mindrt() const {
|
||||
return distributed::cluster::ClusterContext::instance()->initialized();
|
||||
}
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -253,6 +253,9 @@ class BACKEND_EXPORT PSContext {
|
|||
void set_continuous_failure_times(uint32_t continuous_failure_times);
|
||||
uint32_t continuous_failure_times();
|
||||
|
||||
// Whether distributed MindRT is enabled.
|
||||
bool enable_distributed_mindrt() const;
|
||||
|
||||
private:
|
||||
PSContext()
|
||||
: ps_enabled_(false),
|
||||
|
|
|
@ -25,6 +25,12 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
RecvActor::~RecvActor() {
|
||||
if (server_) {
|
||||
server_->Finalize();
|
||||
}
|
||||
}
|
||||
|
||||
void RecvActor::SetOpcontext(OpContext<DeviceTensor> *const op_context) {
|
||||
std::unique_lock<std::mutex> lock(context_mtx_);
|
||||
MS_EXCEPTION_IF_NULL(op_context);
|
||||
|
|
|
@ -36,8 +36,11 @@ class RecvActor : public RpcActor {
|
|||
const std::set<size_t> &modifiable_ref_output_indexes)
|
||||
: RpcActor(name, kernel, device_context, memory_manager_aid, debug_aid, recorder_aid, strategy,
|
||||
modifiable_ref_input_indexes, modifiable_ref_output_indexes, KernelTransformType::kRecvActor),
|
||||
ip_(""),
|
||||
port_(0),
|
||||
server_(nullptr),
|
||||
is_context_valid_(false) {}
|
||||
~RecvActor() override = default;
|
||||
~RecvActor() override;
|
||||
|
||||
// Besides set the op context, this method also notify the message handler to 'RunOpInterProcessData'.
|
||||
void SetOpcontext(OpContext<DeviceTensor> *const op_context) override;
|
||||
|
|
|
@ -20,6 +20,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
SendActor::~SendActor() {
|
||||
if (client_) {
|
||||
client_->Disconnect(server_url_);
|
||||
client_->Finalize();
|
||||
}
|
||||
}
|
||||
|
||||
void SendActor::SetRouteInfo(uint32_t, const std::string &, const std::string &send_src_node_name,
|
||||
const std::string &send_dst_node_name) {
|
||||
auto peer_actor_id = inter_process_edge_name_;
|
||||
|
@ -38,12 +45,12 @@ bool SendActor::ConnectServer() {
|
|||
MS_EXCEPTION_IF_NULL(actor_route_table_proxy_);
|
||||
auto peer_actor_address = actor_route_table_proxy_->LookupRoute(peer_actor_id);
|
||||
// If route is successfully looked up, peer_actor_address is not empty.
|
||||
std::string server_url = peer_actor_address.ip() + ":" + std::to_string(peer_actor_address.port());
|
||||
if (!client_->Connect(server_url)) {
|
||||
MS_LOG(EXCEPTION) << "Failed to connect to server of actor " << peer_actor_id << ", server_url: " << server_url;
|
||||
server_url_ = peer_actor_address.ip() + ":" + std::to_string(peer_actor_address.port());
|
||||
if (!client_->Connect(server_url_)) {
|
||||
MS_LOG(EXCEPTION) << "Failed to connect to server of actor " << peer_actor_id << ", server_url: " << server_url_;
|
||||
}
|
||||
MS_LOG(INFO) << "Successfully connect to server " << server_url << ", inter-process edge name: " << peer_actor_id;
|
||||
peer_actor_urls_[peer_actor_id] = server_url;
|
||||
MS_LOG(INFO) << "Successfully connect to server " << server_url_ << ", inter-process edge name: " << peer_actor_id;
|
||||
peer_actor_urls_[peer_actor_id] = server_url_;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -33,8 +33,10 @@ class SendActor : public RpcActor {
|
|||
GraphExecutionStrategy strategy, const std::set<size_t> &modifiable_ref_input_indexes,
|
||||
const std::set<size_t> &modifiable_ref_output_indexes)
|
||||
: RpcActor(name, kernel, device_context, memory_manager_aid, debug_aid, recorder_aid, strategy,
|
||||
modifiable_ref_input_indexes, modifiable_ref_output_indexes, KernelTransformType::kSendActor) {}
|
||||
~SendActor() override = default;
|
||||
modifiable_ref_input_indexes, modifiable_ref_output_indexes, KernelTransformType::kSendActor),
|
||||
server_url_(""),
|
||||
client_(nullptr) {}
|
||||
~SendActor() override;
|
||||
|
||||
// Set send actor's destination peer info, in another word, send actor's output.
|
||||
void SetRouteInfo(uint32_t dst_rank, const std::string &dst_role, const std::string &send_src_node_name,
|
||||
|
@ -57,6 +59,9 @@ class SendActor : public RpcActor {
|
|||
std::vector<std::string> peer_actor_ids_;
|
||||
mindspore::HashMap<std::string, std::string> peer_actor_urls_;
|
||||
|
||||
// The url of the peer recv actor's tcp server.
|
||||
std::string server_url_;
|
||||
|
||||
std::unique_ptr<TCPClient> client_;
|
||||
};
|
||||
|
||||
|
|
|
@ -52,7 +52,7 @@ import mindspore._c_dataengine as cde
|
|||
from mindspore._c_expression import typing
|
||||
|
||||
from mindspore import log as logger
|
||||
from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
|
||||
from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched, _get_ps_context, _enable_distributed_mindrt
|
||||
from mindspore.dataset.engine.offload import GetOffloadModel
|
||||
|
||||
import mindspore.dataset.transforms.c_transforms as c_transforms
|
||||
|
@ -1899,6 +1899,19 @@ class Dataset:
|
|||
def parse(self, children=None):
|
||||
raise NotImplementedError("Dataset has to implement parse method.")
|
||||
|
||||
@staticmethod
|
||||
def _update_data_shard(num_shards, shard_id):
|
||||
"""
|
||||
Update the shard number and shard id if necessary.
|
||||
This is normally used in distributed training mode like Parameter Server training.
|
||||
"""
|
||||
# If this is in distributed execution mode,
|
||||
# the shard number and shard id might need to be updated according to the process's rank or role.
|
||||
if _enable_distributed_mindrt() and _is_role_pserver():
|
||||
num_shards = _get_ps_context("worker_num")
|
||||
shard_id = 0
|
||||
return num_shards, shard_id
|
||||
|
||||
def post_parse(self, ir_node):
|
||||
if self.cache:
|
||||
ir_node = ir_node.set_cache_client(self.cache.cache_client)
|
||||
|
@ -2163,6 +2176,7 @@ class MappableDataset(SourceDataset):
|
|||
|
||||
def __init__(self, num_parallel_workers=None, sampler=None, num_samples=None, shuffle=None, num_shards=None,
|
||||
shard_id=None, cache=None):
|
||||
num_shards, shard_id = self._update_data_shard(num_shards, shard_id)
|
||||
super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle,
|
||||
num_shards=num_shards, shard_id=shard_id, cache=cache)
|
||||
self.shuffle_flag = replace_none(shuffle, True)
|
||||
|
|
|
@ -302,6 +302,14 @@ def _is_fl_mode():
|
|||
return _get_ps_context("server_mode") in ("FEDERATED_LEARNING", "HYBRID_TRAINING")
|
||||
|
||||
|
||||
def _enable_distributed_mindrt():
|
||||
'''
|
||||
Whether the distributed MindRT is enabled.
|
||||
This method is used to distinguish from old distributed training mode.
|
||||
'''
|
||||
return ps_context().enable_distributed_mindrt()
|
||||
|
||||
|
||||
def _check_value(key, value):
|
||||
"""
|
||||
Validate the value for parameter server context keys.
|
||||
|
|
Loading…
Reference in New Issue