Fix recv actor stuck issue.

parent 5b1772d2b3
commit 4c74b600ea
@@ -253,12 +253,12 @@ int Connection::ReceiveMessage() {
     return 0;
   }

-  std::unique_ptr<MessageBase> msg(recv_message);
+  std::shared_ptr<MessageBase> msg(recv_message);
   recv_message = nullptr;

   // Call msg handler if set
   if (message_handler) {
-    message_handler(std::move(msg));
+    message_handler(msg);
   } else {
     MS_LOG(INFO) << "Message handler was not found";
   }
@@ -29,7 +29,7 @@
 namespace mindspore {
 namespace distributed {
 namespace rpc {
-using MessageHandler = std::function<void(std::unique_ptr<MessageBase> &&msg)>;
+using MessageHandler = std::function<void(const std::shared_ptr<MessageBase> &)>;
 using DeleteCallBack = void (*)(const std::string &from, const std::string &to);
 using ConnectionCallBack = void (*)(void *conn);
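The two hunks above are the heart of the ownership change: a handler taking std::unique_ptr<MessageBase> && forces the TCP thread to consume the message synchronously, while const std::shared_ptr<MessageBase> & lets the handler capture the message and hand it to another thread later. A minimal sketch of the difference (standalone illustrative types, not the MindSpore sources):

#include <functional>
#include <memory>
#include <string>

struct MessageBase {
  std::string body;
};

// Move-only style: ownership must transfer on the spot, so the message
// cannot be re-posted to a queue once the handler returns.
using UniqueHandler = std::function<void(std::unique_ptr<MessageBase> &&)>;

// Shared style: the handler may copy the pointer into a deferred task;
// the payload stays alive until the last reference drops.
using SharedHandler = std::function<void(const std::shared_ptr<MessageBase> &)>;

int main() {
  auto msg = std::make_shared<MessageBase>();
  msg->body = "payload";
  SharedHandler handler = [](const std::shared_ptr<MessageBase> &m) {
    auto queued = m;  // cheap ref-count bump; safe to use after the handler returns
    (void)queued;
  };
  handler(msg);
  return 0;
}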
@@ -37,7 +37,7 @@ class RpcKernelMod : public NativeCpuKernelMod {
   void InitKernel(const CNodePtr &kernel_node) override { return; }

   // Set remote data as input.
-  void SetRemoteInput(std::unique_ptr<MessageBase> &&) {}
+  void SetRemoteInput(const std::shared_ptr<MessageBase> &msg) {}

  private:
 };
@@ -16,31 +16,24 @@
 #include "runtime/graph_scheduler/actor/rpc/recv_actor.h"

 #include <memory>
 #include <utility>
 #include <functional>
+#include <condition_variable>
 #include "plugin/device/cpu/kernel/rpc/rpc_recv_kernel.h"

 namespace mindspore {
 namespace runtime {
-void RecvActor::RunOpInterProcessData(std::unique_ptr<MessageBase> &&msg, OpContext<DeviceTensor> *const context) {
-  MS_ERROR_IF_NULL_WO_RET_VAL(msg);
-  MS_ERROR_IF_NULL_WO_RET_VAL(op_context_);
-  auto &sequential_num = context->sequential_num_;
-  (void)input_op_inter_process_[sequential_num].emplace_back(msg->From().Name());
-
-  auto is_run = CheckRunningCondition(context);
-  MS_LOG(INFO) << "Actor(" << GetAID().Name() << ") receive the input op inter-process. Edge is "
-               << inter_process_edge_name_ << ". Check running condition:" << is_run;
-
-  // Parse the message from remote peer and set to rpc recv kernel.
-  auto recv_kernel_mod = dynamic_cast<kernel::RpcKernelMod *>(kernel_info_->MutableKernelMod());
-  MS_ERROR_IF_NULL_WO_RET_VAL(recv_kernel_mod);
-  // We set remote data by the interface of the rpc kernel, because currently there's no remote input for a kernel mod.
-  recv_kernel_mod->SetRemoteInput(std::move(msg));
-
-  if (is_run) {
-    Run(context);
-  }
-  return;
+void RecvActor::SetOpcontext(OpContext<DeviceTensor> *const op_context) {
+  std::unique_lock<std::mutex> lock(context_mtx_);
+  op_context_ = op_context;
+  is_context_valid_ = true;
+  context_cv_.notify_all();
+}
+
+void RecvActor::ResetOpcontext() {
+  std::unique_lock<std::mutex> lock(context_mtx_);
+  is_context_valid_ = false;
 }

 void RecvActor::SetRouteInfo(uint32_t, const std::string &, const std::string &recv_src_node_name,
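SetOpcontext and ResetOpcontext implement a gate around the op context: the mutex protects is_context_valid_, and the condition variable wakes any message handler that went to sleep while the context was being swapped between sinked loops. The same gate in isolation, as a sketch with generic names rather than the actor types:

#include <condition_variable>
#include <mutex>

class ContextGate {
 public:
  // Publish a fresh context and wake every waiting consumer.
  void Set(void *ctx) {
    std::unique_lock<std::mutex> lock(mtx_);
    ctx_ = ctx;
    valid_ = true;
    cv_.notify_all();
  }

  // Invalidate the context before it is destroyed; consumers will block.
  void Reset() {
    std::unique_lock<std::mutex> lock(mtx_);
    valid_ = false;
  }

  // Block until a valid context is available, then return it.
  void *Wait() {
    std::unique_lock<std::mutex> lock(mtx_);
    cv_.wait(lock, [this] { return valid_; });
    return ctx_;
  }

 private:
  void *ctx_ = nullptr;
  bool valid_ = false;
  std::mutex mtx_;
  std::condition_variable cv_;
};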
@@ -59,6 +52,7 @@ bool RecvActor::StartServer() {
   }

+  // Step 2: Set the message handler of the server.
+  server_->SetMessageHandler(std::bind(&RecvActor::HandleMessage, this, std::placeholders::_1));
+
   // Step 2: Register the server address to route table. The server should not be connected before this step is done.
   ActorAddress recv_actor_addresss;
@@ -73,6 +67,28 @@ bool RecvActor::StartServer() {
   return true;
 }

+void RecvActor::RunOpInterProcessData(const std::shared_ptr<MessageBase> &msg, OpContext<DeviceTensor> *const context) {
+  MS_ERROR_IF_NULL_WO_RET_VAL(msg);
+  MS_ERROR_IF_NULL_WO_RET_VAL(op_context_);
+  auto &sequential_num = context->sequential_num_;
+  (void)input_op_inter_process_[sequential_num].emplace_back(msg->From().Name());
+
+  auto is_run = CheckRunningCondition(context);
+  MS_LOG(INFO) << "Actor(" << GetAID().Name() << ") receive the input op inter-process. Edge is "
+               << inter_process_edge_name_ << ". Check running condition:" << is_run;
+
+  // Parse the message from remote peer and set to rpc recv kernel.
+  auto recv_kernel_mod = dynamic_cast<kernel::RpcKernelMod *>(kernel_info_->MutableKernelMod());
+  MS_ERROR_IF_NULL_WO_RET_VAL(recv_kernel_mod);
+  // We set remote data by the interface of the rpc kernel, because currently there's no remote input for a kernel mod.
+  recv_kernel_mod->SetRemoteInput(msg);
+
+  if (is_run) {
+    Run(context);
+  }
+  return;
+}
+
 bool RecvActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) const {
   MS_EXCEPTION_IF_NULL(context);
   // Step 1: Judge data and control inputs are satisfied.
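RunOpInterProcessData appends the sender's name under the context's sequential number and only calls Run once CheckRunningCondition reports that every expected input has arrived. A reduced model of that bookkeeping, with hypothetical names (the real check also covers data and control arrows):

#include <map>
#include <string>
#include <vector>

// Tracks which remote peers have delivered data for a given step.
class InterProcessInputTracker {
 public:
  explicit InterProcessInputTracker(size_t expected_inputs) : expected_(expected_inputs) {}

  // Record one arrival; return true when the step has all of its inputs.
  bool Arrive(int sequential_num, const std::string &from) {
    auto &arrived = inputs_[sequential_num];
    arrived.push_back(from);
    return arrived.size() >= expected_;
  }

 private:
  size_t expected_;
  std::map<int, std::vector<std::string>> inputs_;
};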
@@ -100,10 +116,15 @@ bool RecvActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) const {
   return true;
 }

-void RecvActor::HandleMessage(std::unique_ptr<MessageBase> &&msg) {
+void RecvActor::HandleMessage(const std::shared_ptr<MessageBase> &msg) {
+  // Block the message handler if the context is invalid.
+  std::unique_lock<std::mutex> lock(context_mtx_);
+  context_cv_.wait(lock, [this] { return is_context_valid_; });
+  lock.unlock();
+
   MS_ERROR_IF_NULL_WO_RET_VAL(msg);
   MS_ERROR_IF_NULL_WO_RET_VAL(op_context_);
-  RunOpInterProcessData(std::move(msg), op_context_);
+  ActorDispatcher::Send(GetAID(), &RecvActor::RunOpInterProcessData, msg, op_context_);
 }
 }  // namespace runtime
 }  // namespace mindspore
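This hunk is the actual stuck-actor fix. Previously HandleMessage ran RunOpInterProcessData inline on the TCP server thread, so a message arriving while the op context was being replaced could touch a dead context or wedge the server. Now the handler first waits on the gate, then posts the work to the actor's own mailbox via ActorDispatcher::Send; the shared_ptr signature from the first hunk is what makes binding msg into that deferred call legal. A toy mailbox showing the post-then-drain shape (illustrative only, not the ActorDispatcher API):

#include <functional>
#include <mutex>
#include <queue>

// Work is queued by any thread and drained by the actor's own thread,
// so network threads never execute actor logic directly.
class Mailbox {
 public:
  void Send(std::function<void()> task) {
    std::lock_guard<std::mutex> lock(mtx_);
    queue_.push(std::move(task));
  }

  // Called from the actor thread's loop.
  void DrainOne() {
    std::function<void()> task;
    {
      std::lock_guard<std::mutex> lock(mtx_);
      if (queue_.empty()) {
        return;
      }
      task = std::move(queue_.front());
      queue_.pop();
    }
    task();  // run outside the lock
  }

 private:
  std::queue<std::function<void()>> queue_;
  std::mutex mtx_;
};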
@@ -18,9 +18,11 @@
 #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_RPC_RECV_ACTOR_H_

 #include <set>
+#include <mutex>
 #include <vector>
 #include <string>
 #include <memory>
+#include <condition_variable>
 #include "runtime/graph_scheduler/actor/rpc/rpc_actor.h"

 namespace mindspore {
@@ -36,8 +38,12 @@ class RecvActor : public RpcActor {
                   modifiable_ref_input_indexes, modifiable_ref_output_indexes, KernelTransformType::kRecvActor) {}
   ~RecvActor() override = default;

-  // When inter-process data is received, this method is called.
-  void RunOpInterProcessData(std::unique_ptr<MessageBase> &&msg, OpContext<DeviceTensor> *const context);
+  // Besides setting the op context, this method also notifies the message handler to run 'RunOpInterProcessData'.
+  void SetOpcontext(OpContext<DeviceTensor> *const op_context) override;
+
+  // This method means the op context is invalid now. If the message handler is called while the op context is
+  // invalid, it should be blocked until 'SetOpcontext' is called.
+  void ResetOpcontext() override;

   // Set recv actor's source peer info, in other words, recv actor's input.
   void SetRouteInfo(uint32_t src_rank, const std::string &src_role, const std::string &recv_src_node_name,
@@ -47,20 +53,27 @@ class RecvActor : public RpcActor {
   bool StartServer();

  protected:
+  // When inter-process data is received, this method is called.
+  void RunOpInterProcessData(const std::shared_ptr<MessageBase> &msg, OpContext<DeviceTensor> *const context);
+
   // Besides the checking method in base class AbstractActor, condition of inter-process arrows should be checked for
   // recv actor.
   bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const override;

  private:
-  void HandleMessage(std::unique_ptr<MessageBase> &&msg);
-
   friend class GraphScheduler;

+  // The message callback of the tcp server.
+  void HandleMessage(const std::shared_ptr<MessageBase> &msg);
+
   // The network address of this recv actor. It's generated automatically by rpc module.
   std::string ip_;
   uint16_t port_;

   std::unique_ptr<TCPServer> server_;
+
+  // The variables used to ensure thread safety of the op context visited by the recv actor.
+  bool is_context_valid_;
+  std::mutex context_mtx_;
+  std::condition_variable context_cv_;
 };

 using RecvActorPtr = std::shared_ptr<RecvActor>;
@@ -56,7 +56,11 @@ class RpcActor : public KernelActor {

   // Normally, an actor's op_context is passed by its input actor, but rpc actors could be triggered by inter-process
   // arrows which do not contain op_context. So we need to set op_context manually.
-  void SetOpcontext(OpContext<DeviceTensor> *const op_context);
+  virtual void SetOpcontext(OpContext<DeviceTensor> *const op_context);
+
+  // Reset op context. Because op context is recreated for each sinked loop, this method should be called after
+  // each sinked loop is done in case rpc actors visit the invalid op context.
+  virtual void ResetOpcontext() {}

   // Set the actor route proxy for rpc actors.
   void SetActorRouteRableProxy(const ActorRouteTableProxyPtr &proxy);
@@ -69,7 +73,7 @@ class RpcActor : public KernelActor {
                             const std::string &dst_node_name) {}

  protected:
-  // The op context to run rpc actor inter-process op.
+  // The op context to run rpc actor inter-process op. Set by method 'SetOpcontext'.
   OpContext<DeviceTensor> *op_context_;

   // The inter-process edge name. It is also used as the actor id for route. It's a string consisting of the source node name
@@ -422,8 +422,8 @@ void GraphScheduler::Run(ActorSet *const actor_set, const std::vector<DeviceContext *> &device_contexts,

 #ifdef ENABLE_RPC_ACTOR
   // Set OpContext to rpc node scheduler.
-  MS_EXCEPTION_IF_NULL(rpc_node_scheduler_);
-  rpc_node_scheduler_->SetOpcontext(&op_context);
+  auto op_context_setter = std::make_shared<RpcActorOpContextSetter>(rpc_node_scheduler_.get(), &op_context);
+  MS_EXCEPTION_IF_NULL(op_context_setter);
 #endif

   if ((strategy == GraphExecutionStrategy::kStep) && IsSingleOpActorSet(actor_set)) {
@@ -130,6 +130,19 @@ void RpcNodeScheduler::SetOpcontext(OpContext<DeviceTensor> *const op_context) {
   }
 }

+void RpcNodeScheduler::ResetOpcontext() {
+  MS_EXCEPTION_IF_NULL(rpc_actor_set_);
+
+  for (auto &recv_actor : rpc_actor_set_->recv_actors_) {
+    MS_EXCEPTION_IF_NULL(recv_actor);
+    recv_actor->ResetOpcontext();
+  }
+  for (auto &send_actor : rpc_actor_set_->send_actors_) {
+    MS_EXCEPTION_IF_NULL(send_actor);
+    send_actor->ResetOpcontext();
+  }
+}
+
 ActorRouteTableProxyPtr RpcNodeScheduler::CreateRouteTableProxy() {
   ActorRouteTableProxyPtr actor_route_table_proxy;
   if (!ClusterContext::instance()->IsScheduler()) {
@@ -52,15 +52,32 @@ class RpcNodeScheduler {
   void InsertSendActor(const SendActorPtr &send_actor);
   void InsertRecvActor(const RecvActorPtr &recv_actor);

-  // Set op_context to rpc actors.
+  // Set op context to rpc actors.
   void SetOpcontext(OpContext<DeviceTensor> *const op_context);

+  // Reset op context for rpc actors.
+  void ResetOpcontext();
+
  private:
   // Create new route table proxy.
   ActorRouteTableProxyPtr CreateRouteTableProxy();

   RpcActorSetPtr rpc_actor_set_;
 };

+// The setter of op context for rpc actors.
+class RpcActorOpContextSetter {
+ public:
+  explicit RpcActorOpContextSetter(RpcNodeScheduler *rpc_node_scheduler, OpContext<DeviceTensor> *const op_context)
+      : rpc_node_scheduler_(rpc_node_scheduler), op_context_(op_context) {
+    rpc_node_scheduler_->SetOpcontext(op_context_);
+  }
+  ~RpcActorOpContextSetter() { rpc_node_scheduler_->ResetOpcontext(); }
+
+ private:
+  RpcNodeScheduler *rpc_node_scheduler_;
+  OpContext<DeviceTensor> *op_context_;
+};
+
 }  // namespace runtime
 }  // namespace mindspore
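RpcActorOpContextSetter is an RAII scope guard: its constructor publishes the new op context to every rpc actor and its destructor calls ResetOpcontext, so the reset happens on every exit path of GraphScheduler::Run, including early returns and exceptions, which the old bare SetOpcontext call could not guarantee. The same idea as a generic sketch (hypothetical names):

#include <functional>
#include <utility>

// Runs the given callable when the enclosing scope ends, however it ends.
class ScopeGuard {
 public:
  explicit ScopeGuard(std::function<void()> on_exit) : on_exit_(std::move(on_exit)) {}
  ~ScopeGuard() { on_exit_(); }
  ScopeGuard(const ScopeGuard &) = delete;
  ScopeGuard &operator=(const ScopeGuard &) = delete;

 private:
  std::function<void()> on_exit_;
};

// Usage mirroring the patch:
//   scheduler->SetOpcontext(&op_context);
//   ScopeGuard guard([&] { scheduler->ResetOpcontext(); });
//   ...  // any return or throw past this point still resets the context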
@@ -152,7 +152,7 @@ TEST_F(TCPTest, SendOneMessage) {
   bool ret = server->Initialize(server_url);
   ASSERT_TRUE(ret);

-  server->SetMessageHandler([](std::unique_ptr<MessageBase> &&message) -> void { IncrDataMsgNum(1); });
+  server->SetMessageHandler([](const std::shared_ptr<MessageBase> &message) -> void { IncrDataMsgNum(1); });

   // Start the tcp client.
   auto client_url = "127.0.0.1:1234";
@@ -191,7 +191,7 @@ TEST_F(TCPTest, sendTwoMessages) {
   bool ret = server->Initialize(server_url);
   ASSERT_TRUE(ret);

-  server->SetMessageHandler([](std::unique_ptr<MessageBase> &&message) -> void { IncrDataMsgNum(1); });
+  server->SetMessageHandler([](const std::shared_ptr<MessageBase> &message) -> void { IncrDataMsgNum(1); });

   // Start the tcp client.
   auto client_url = "127.0.0.1:1234";