!30746 Fix recv actor stuck issue.

Merge pull request !30746 from ZPaC/sync-route-table
This commit is contained in:
i-robot 2022-03-04 03:01:03 +00:00 committed by Gitee
commit fdf7aebd78
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
10 changed files with 105 additions and 37 deletions

View File

@ -253,12 +253,12 @@ int Connection::ReceiveMessage() {
return 0;
}
std::unique_ptr<MessageBase> msg(recv_message);
std::shared_ptr<MessageBase> msg(recv_message);
recv_message = nullptr;
// Call msg handler if set
if (message_handler) {
message_handler(std::move(msg));
message_handler(msg);
} else {
MS_LOG(INFO) << "Message handler was not found";
}

View File

@ -29,7 +29,7 @@
namespace mindspore {
namespace distributed {
namespace rpc {
using MessageHandler = std::function<void(std::unique_ptr<MessageBase> &&msg)>;
using MessageHandler = std::function<void(const std::shared_ptr<MessageBase> &)>;
using DeleteCallBack = void (*)(const std::string &from, const std::string &to);
using ConnectionCallBack = void (*)(void *conn);

View File

@ -37,7 +37,7 @@ class RpcKernelMod : public NativeCpuKernelMod {
void InitKernel(const CNodePtr &kernel_node) override { return; }
// Set remote data as input.
void SetRemoteInput(std::unique_ptr<MessageBase> &&) {}
void SetRemoteInput(const std::shared_ptr<MessageBase> &msg) {}
private:
};

View File

@ -16,31 +16,24 @@
#include "runtime/graph_scheduler/actor/rpc/recv_actor.h"
#include <memory>
#include <utility>
#include <functional>
#include <condition_variable>
#include "plugin/device/cpu/kernel/rpc/rpc_recv_kernel.h"
namespace mindspore {
namespace runtime {
void RecvActor::RunOpInterProcessData(std::unique_ptr<MessageBase> &&msg, OpContext<DeviceTensor> *const context) {
MS_ERROR_IF_NULL_WO_RET_VAL(msg);
MS_ERROR_IF_NULL_WO_RET_VAL(op_context_);
auto &sequential_num = context->sequential_num_;
(void)input_op_inter_process_[sequential_num].emplace_back(msg->From().Name());
// Stores the op context for this recv actor, marks it valid and wakes every thread waiting on 'context_cv_'.
void RecvActor::SetOpcontext(OpContext<DeviceTensor> *const op_context) {
  // The pointer and the validity flag are updated together while 'context_mtx_' is held; the guard releases the mutex
  // when it goes out of scope at the end of the function.
  const std::lock_guard<std::mutex> guard(context_mtx_);
  op_context_ = op_context;
  is_context_valid_ = true;

  // Waiters re-evaluate their predicate on 'is_context_valid_' once they reacquire the mutex.
  context_cv_.notify_all();
}
auto is_run = CheckRunningCondition(context);
MS_LOG(INFO) << "Actor(" << GetAID().Name() << ") receive the input op inter-process. Edge is "
<< inter_process_edge_name_ << ". Check running condition:" << is_run;
// Parse the message from remote peer and set to rpc recv kernel.
auto recv_kernel_mod = dynamic_cast<kernel::RpcKernelMod *>(kernel_info_->MutableKernelMod());
MS_ERROR_IF_NULL_WO_RET_VAL(recv_kernel_mod);
// We set remote data by the interface of the rpc kernel, because currently there's no remote input for a kernel mod.
recv_kernel_mod->SetRemoteInput(std::move(msg));
if (is_run) {
Run(context);
}
return;
// Marks the op context of this recv actor as invalid. Only the validity flag is changed; 'op_context_' itself is left
// untouched.
void RecvActor::ResetOpcontext() {
  // Taken under 'context_mtx_', the same mutex 'SetOpcontext' uses when it sets the flag.
  const std::lock_guard<std::mutex> guard(context_mtx_);
  is_context_valid_ = false;
}
void RecvActor::SetRouteInfo(uint32_t, const std::string &, const std::string &recv_src_node_name,
@ -59,6 +52,7 @@ bool RecvActor::StartServer() {
}
// Step 2: Set the message handler of the server.
server_->SetMessageHandler(std::bind(&RecvActor::HandleMessage, this, std::placeholders::_1));
// Step 3: Register the server address to route table. The server should not be connected before this step is done.
ActorAddress recv_actor_addresss;
@ -73,6 +67,28 @@ bool RecvActor::StartServer() {
return true;
}
// Handles one inter-process message for this recv actor.
//
// Records the sender name of 'msg' for the step identified by 'context->sequential_num_', evaluates the running
// condition, hands the message to the rpc recv kernel as its remote input and, when the running condition holds, runs
// the actor.
//
// msg:     the message received from the remote peer. Must not be null.
// context: the op context of the current step. Must not be null.
//
// On any null argument or missing rpc kernel mod the error is logged by the checking macro and the method returns
// without running the actor.
void RecvActor::RunOpInterProcessData(const std::shared_ptr<MessageBase> &msg, OpContext<DeviceTensor> *const context) {
  MS_ERROR_IF_NULL_WO_RET_VAL(msg);
  MS_ERROR_IF_NULL_WO_RET_VAL(op_context_);
  // 'context' is the pointer that is actually dereferenced below, so it has to be validated too; checking only the
  // member 'op_context_' does not protect against a null argument.
  MS_ERROR_IF_NULL_WO_RET_VAL(context);

  auto &sequential_num = context->sequential_num_;
  (void)input_op_inter_process_[sequential_num].emplace_back(msg->From().Name());

  // The running condition is evaluated after the new input has been recorded above.
  auto is_run = CheckRunningCondition(context);
  MS_LOG(INFO) << "Actor(" << GetAID().Name() << ") receive the input op inter-process. Edge is "
               << inter_process_edge_name_ << ". Check running condition:" << is_run;

  // Parse the message from remote peer and set to rpc recv kernel.
  auto recv_kernel_mod = dynamic_cast<kernel::RpcKernelMod *>(kernel_info_->MutableKernelMod());
  MS_ERROR_IF_NULL_WO_RET_VAL(recv_kernel_mod);
  // We set remote data by the interface of the rpc kernel, because currently there's no remote input for a kernel mod.
  recv_kernel_mod->SetRemoteInput(msg);

  if (is_run) {
    Run(context);
  }
}
bool RecvActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) const {
MS_EXCEPTION_IF_NULL(context);
// Step 1: Judge data and control inputs are satisfied.
@ -100,10 +116,15 @@ bool RecvActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) co
return true;
}
// The message callback bound to the tcp server in 'StartServer'.
// It waits until the op context has been marked valid, then passes the message together with 'op_context_' on to
// 'RunOpInterProcessData'.
void RecvActor::HandleMessage(std::unique_ptr<MessageBase> &&msg) {
void RecvActor::HandleMessage(const std::shared_ptr<MessageBase> &msg) {
// Block the message handler while the op context is invalid: wait on 'context_cv_' until 'is_context_valid_' is true.
std::unique_lock<std::mutex> lock(context_mtx_);
context_cv_.wait(lock, [this] { return is_context_valid_; });
// NOTE(review): the mutex is released here, but 'op_context_' is still read below. A 'ResetOpcontext' running in
// between would leave this handler using a context that was just marked invalid - confirm this window is acceptable.
lock.unlock();
MS_ERROR_IF_NULL_WO_RET_VAL(msg);
MS_ERROR_IF_NULL_WO_RET_VAL(op_context_);
RunOpInterProcessData(std::move(msg), op_context_);
// Hand the message to 'RunOpInterProcessData' through the actor dispatcher, addressed to this actor's own AID.
// NOTE(review): presumably this runs the method on the actor's thread rather than the tcp server's - confirm.
ActorDispatcher::Send(GetAID(), &RecvActor::RunOpInterProcessData, msg, op_context_);
}
} // namespace runtime
} // namespace mindspore

View File

@ -18,9 +18,11 @@
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_RPC_RECV_ACTOR_H_
#include <set>
#include <mutex>
#include <vector>
#include <string>
#include <memory>
#include <condition_variable>
#include "runtime/graph_scheduler/actor/rpc/rpc_actor.h"
namespace mindspore {
@ -36,8 +38,12 @@ class RecvActor : public RpcActor {
modifiable_ref_input_indexes, modifiable_ref_output_indexes, KernelTransformType::kRecvActor) {}
~RecvActor() override = default;
// When an inter-process data received, this method is called.
void RunOpInterProcessData(std::unique_ptr<MessageBase> &&msg, OpContext<DeviceTensor> *const context);
// Besides setting the op context, this method also notifies the message handler to run 'RunOpInterProcessData'.
void SetOpcontext(OpContext<DeviceTensor> *const op_context) override;
// This method means the op context is invalid now. If the message handler is called while the op context is invalid,
// it should be blocked until 'SetOpcontext' is called.
void ResetOpcontext() override;
// Set recv actor's source peer info, in other words, recv actor's input.
void SetRouteInfo(uint32_t src_rank, const std::string &src_role, const std::string &recv_src_node_name,
@ -47,20 +53,27 @@ class RecvActor : public RpcActor {
bool StartServer();
protected:
// When an inter-process data received, this method is called.
void RunOpInterProcessData(const std::shared_ptr<MessageBase> &msg, OpContext<DeviceTensor> *const context);
// Besides the checking method in base class AbstractActor, condition of inter-process arrows should be checked for
// recv actor.
bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const override;
private:
void HandleMessage(std::unique_ptr<MessageBase> &&msg);
friend class GraphScheduler;
// The message callback of the tcp server.
void HandleMessage(const std::shared_ptr<MessageBase> &msg);
// The network address of this recv actor. It's generated automatically by rpc module.
std::string ip_;
uint16_t port_;
std::unique_ptr<TCPServer> server_;
// The variables used to make the recv actor's access to the op context thread-safe.
// 'is_context_valid_' is the predicate the message handler waits on through 'context_cv_'; it is written under
// 'context_mtx_' by 'SetOpcontext' (true) and 'ResetOpcontext' (false). It is explicitly initialized to false so that a
// message arriving before the first 'SetOpcontext' call blocks instead of reading an indeterminate value.
bool is_context_valid_{false};
std::mutex context_mtx_;
std::condition_variable context_cv_;
};
using RecvActorPtr = std::shared_ptr<RecvActor>;

View File

@ -56,7 +56,11 @@ class RpcActor : public KernelActor {
// Normally, an actor's op_context is passed by its input actor, but rpc actors could be triggered by inter-process
// arrows which do not contain op_context. So we need to set op_context manually.
void SetOpcontext(OpContext<DeviceTensor> *const op_context);
virtual void SetOpcontext(OpContext<DeviceTensor> *const op_context);
// Reset op context. Because the op context is recreated for each sinked loop, this method should be called after
// each sinked loop is done so that rpc actors do not visit an invalid op context.
// The default implementation does nothing; subclasses override it when they need to react to the reset.
virtual void ResetOpcontext() {}
// Set the actor route proxy for rpc actors.
void SetActorRouteRableProxy(const ActorRouteTableProxyPtr &proxy);
@ -69,7 +73,7 @@ class RpcActor : public KernelActor {
const std::string &dst_node_name) {}
protected:
// The op context to run rpc actor inter-process op.
// The op context to run rpc actor inter-process op. Set by method 'SetOpcontext'.
OpContext<DeviceTensor> *op_context_;
// The inter-process edge name. It is also used as the actor id for route. It's a string consists of source node name

View File

@ -422,8 +422,8 @@ void GraphScheduler::Run(ActorSet *const actor_set, const std::vector<DeviceCont
#ifdef ENABLE_RPC_ACTOR
// Set OpContext to rpc node scheduler.
MS_EXCEPTION_IF_NULL(rpc_node_scheduler_);
rpc_node_scheduler_->SetOpcontext(&op_context);
auto op_context_setter = std::make_shared<RpcActorOpContextSetter>(rpc_node_scheduler_.get(), &op_context);
MS_EXCEPTION_IF_NULL(op_context_setter);
#endif
if ((strategy == GraphExecutionStrategy::kStep) && IsSingleOpActorSet(actor_set)) {

View File

@ -130,6 +130,19 @@ void RpcNodeScheduler::SetOpcontext(OpContext<DeviceTensor> *const op_context) {
}
}
// Resets the op context of every rpc actor held by this scheduler: all recv actors first, then all send actors.
// Raises through MS_EXCEPTION_IF_NULL when the actor set or any actor in it is null.
void RpcNodeScheduler::ResetOpcontext() {
  MS_EXCEPTION_IF_NULL(rpc_actor_set_);
  const auto &actor_set = *rpc_actor_set_;

  for (const auto &recv_actor : actor_set.recv_actors_) {
    MS_EXCEPTION_IF_NULL(recv_actor);
    recv_actor->ResetOpcontext();
  }

  for (const auto &send_actor : actor_set.send_actors_) {
    MS_EXCEPTION_IF_NULL(send_actor);
    send_actor->ResetOpcontext();
  }
}
ActorRouteTableProxyPtr RpcNodeScheduler::CreateRouteTableProxy() {
ActorRouteTableProxyPtr actor_route_table_proxy;
if (!ClusterContext::instance()->IsScheduler()) {

View File

@ -52,15 +52,32 @@ class RpcNodeScheduler {
void InsertSendActor(const SendActorPtr &send_actor);
void InsertRecvActor(const RecvActorPtr &recv_actor);
// Set op_context to rpc actors.
// Set op context to rpc actors.
void SetOpcontext(OpContext<DeviceTensor> *const op_context);
// Reset op context for rpc actors.
void ResetOpcontext();
private:
// Create new route table proxy.
ActorRouteTableProxyPtr CreateRouteTableProxy();
RpcActorSetPtr rpc_actor_set_;
};
// The scoped setter of op context for rpc actors.
// The constructor passes 'op_context' to the rpc actors through 'RpcNodeScheduler::SetOpcontext' and the destructor
// calls 'RpcNodeScheduler::ResetOpcontext', so the op context stays set for exactly the lifetime of this object.
//
// The object is non-copyable: a copy would call 'ResetOpcontext' a second time when it is destroyed, possibly while
// the original is still alive and the op context is still in use.
class RpcActorOpContextSetter {
 public:
  // rpc_node_scheduler: the scheduler whose rpc actors receive the op context. Not owned; must outlive this object.
  // op_context: the op context to set. Not owned.
  explicit RpcActorOpContextSetter(RpcNodeScheduler *rpc_node_scheduler, OpContext<DeviceTensor> *const op_context)
      : rpc_node_scheduler_(rpc_node_scheduler), op_context_(op_context) {
    // Guard against a null scheduler instead of dereferencing it.
    if (rpc_node_scheduler_ != nullptr) {
      rpc_node_scheduler_->SetOpcontext(op_context_);
    }
  }

  ~RpcActorOpContextSetter() {
    if (rpc_node_scheduler_ != nullptr) {
      rpc_node_scheduler_->ResetOpcontext();
    }
  }

  // Declaring the copy operations as deleted also suppresses the implicit move operations.
  RpcActorOpContextSetter(const RpcActorOpContextSetter &) = delete;
  RpcActorOpContextSetter &operator=(const RpcActorOpContextSetter &) = delete;

 private:
  RpcNodeScheduler *rpc_node_scheduler_;
  OpContext<DeviceTensor> *op_context_;
};
} // namespace runtime
} // namespace mindspore

View File

@ -152,7 +152,7 @@ TEST_F(TCPTest, SendOneMessage) {
bool ret = server->Initialize(server_url);
ASSERT_TRUE(ret);
server->SetMessageHandler([](std::unique_ptr<MessageBase> &&message) -> void { IncrDataMsgNum(1); });
server->SetMessageHandler([](const std::shared_ptr<MessageBase> &message) -> void { IncrDataMsgNum(1); });
// Start the tcp client.
auto client_url = "127.0.0.1:1234";
@ -191,7 +191,7 @@ TEST_F(TCPTest, sendTwoMessages) {
bool ret = server->Initialize(server_url);
ASSERT_TRUE(ret);
server->SetMessageHandler([](std::unique_ptr<MessageBase> &&message) -> void { IncrDataMsgNum(1); });
server->SetMessageHandler([](const std::shared_ptr<MessageBase> &message) -> void { IncrDataMsgNum(1); });
// Start the tcp client.
auto client_url = "127.0.0.1:1234";