!20841 [MS][LITE]implement shared thread policy with hqueue in mindrt

Merge pull request !20841 from zhaizhiqiang/master
This commit is contained in:
i-robot 2021-10-07 06:48:45 +00:00 committed by Gitee
commit a73befd070
20 changed files with 350 additions and 384 deletions

View File

@ -60,7 +60,7 @@ class SwitchActor : public AbstractActor {
std::vector<KernelWithIndex> formal_parameters_;
// Input data.
std::unordered_map<uuids::uuid *, std::unordered_map<size_t, std::vector<KernelWithIndex>>> input_nodes_;
std::unordered_map<int, std::unordered_map<size_t, std::vector<KernelWithIndex>>> input_nodes_;
// The store node records the value node input of the switch actor.
std::vector<std::pair<size_t, AnfNodePtr>> store_nodes_;

View File

@ -173,8 +173,7 @@ void MemoryManagerActor::SetOpContextMemoryAllocFail(const std::string &kernel_n
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(op_context);
MS_EXCEPTION_IF_NULL(op_context->sequential_num_);
auto step_id = uuids::uuid::ToBytes(*(op_context->sequential_num_));
int step_id = op_context->sequential_num_;
// First memory allocation failure recorded for this step id.
if (mem_alloc_failed_step_ids_.find(step_id) == mem_alloc_failed_step_ids_.end()) {
mem_alloc_failed_step_ids_.clear();

View File

@ -70,7 +70,7 @@ class MemoryManagerActor : public ActorBase {
// The MemoryManagerActor object is used like a single instance. If one actor fails to allocate memory in a batch, it
// sets the failure message in the OpContext and the major thread destroys that OpContext object, so subsequent actors
// cannot set the failure message again. We therefore record each memory allocation failure by the sequential number of
// the batch, which is the key of this set.
std::set<std::string> mem_alloc_failed_step_ids_;
std::set<int> mem_alloc_failed_step_ids_;
};
} // namespace runtime
} // namespace mindspore

View File

@ -328,9 +328,8 @@ void GraphScheduler::Run(const ActorSet *actor_set, const std::vector<std::vecto
// Construct OpContext.
OpContext<DeviceTensor> op_context;
uuids::uuid sequential_num;
std::vector<Promise<int>> result(1);
op_context.sequential_num_ = &sequential_num;
op_context.sequential_num_ = RandInt::Instance().Get();
op_context.results_ = &result;
if ((strategy == GraphExecutionStrategy::kStep) && IsSingleOpActorSet(actor_set)) {

View File

@ -24,14 +24,13 @@
#include <string>
#include <utility>
#include "thread/hqueue.h"
#include "actor/msg.h"
#include "actor/mailbox.h"
namespace mindspore {
class ActorBase;
class ActorMgr;
class ActorPolicy;
class ActorWorker;
class ActorThreadPool;
@ -40,7 +39,7 @@ using ActorReference = std::shared_ptr<ActorBase>;
// should be at least greater than 1
constexpr uint32_t MAX_ACTOR_RECORD_SIZE = 3;
class ActorBase : std::enable_shared_from_this<ActorBase> {
class ActorBase {
public:
inline const AID &GetAID() const { return id; }
@ -58,6 +57,7 @@ class ActorBase : std::enable_shared_from_this<ActorBase> {
startPoint = (startPoint + MAX_ACTOR_RECORD_SIZE - 1) % MAX_ACTOR_RECORD_SIZE;
}
}
ActorBase();
explicit ActorBase(const std::string &name);
explicit ActorBase(const std::string &name, ActorThreadPool *pool);
@ -97,12 +97,12 @@ class ActorBase : std::enable_shared_from_this<ActorBase> {
virtual void Finalize() {}
// KHTTPMsg handler
virtual void HandleHttp(std::unique_ptr<MessageBase> msg) {
virtual void HandleHttp(const std::unique_ptr<MessageBase> &msg) {
MS_LOG(ERROR) << "ACTOR (" << id.Name().c_str() << ") HandleHttp() is not implemented";
}
// KLOCALMsg handler
virtual void HandleLocalMsg(std::unique_ptr<MessageBase> msg) {
virtual void HandleLocalMsg(const std::unique_ptr<MessageBase> &msg) {
MS_LOG(ERROR) << "ACTOR (" << id.Name().c_str() << ") HandleLocalMsg() is not implemented.";
}
@ -200,12 +200,12 @@ class ActorBase : std::enable_shared_from_this<ActorBase> {
void Run();
void Quit();
int EnqueMessage(std::unique_ptr<MessageBase> &&msg);
int EnqueMessage(std::unique_ptr<MessageBase> msg);
void Spawn(const std::shared_ptr<ActorBase> &actor, std::unique_ptr<ActorPolicy> actorThread);
void SetRunningStatus(bool start);
void Spawn(const std::shared_ptr<ActorBase> &actor, std::unique_ptr<MailBox> mailbox);
std::unique_ptr<ActorPolicy> actorPolicy;
std::unique_ptr<MailBox> mailbox;
std::atomic_bool terminating_ = false;
AID id;
std::map<std::string, ActorFunction> actionFunctions;

View File

@ -27,6 +27,7 @@ constexpr int ERRORCODE_SUCCESS = 1;
constexpr int ACTOR_PARAMER_ERR = -101;
constexpr int ACTOR_NOT_FIND = -102;
constexpr int IO_NOT_FIND = -103;
constexpr int ACTOR_TERMINATED = -104;
// TCP module err code -301 ~ -400
// Null

View File

@ -48,6 +48,18 @@ struct OpData {
int index_;
};
// Process-wide source of pseudo-random integers backed by the C library generator.
class RandInt {
 public:
  // Returns the single shared instance, which is created (and seeds the generator) on first use.
  static RandInt &Instance() {
    static RandInt singleton;
    return singleton;
  }

  // Returns the next value produced by rand().
  int Get() { return rand(); }

 private:
  // Seeds the C library generator with the current time; private so that Instance() is the only way in.
  RandInt() { srand(time(nullptr)); }
};
template <typename T>
using OpDataPtr = std::shared_ptr<OpData<T>>;
@ -57,7 +69,7 @@ using OpDataUniquePtr = std::unique_ptr<OpData<T>>;
// The context of opActor running.
template <typename T>
struct OpContext {
uuids::uuid *sequential_num_;
int sequential_num_;
std::vector<OpDataPtr<T>> *output_data_;
std::vector<Promise<int>> *results_;
// Record the error info for print.
@ -100,11 +112,11 @@ class OpActor : public ActorBase {
protected:
// The op data.
std::unordered_map<uuids::uuid *, std::vector<OpData<T> *>> input_op_datas_;
std::unordered_map<int, std::vector<OpData<T> *>> input_op_datas_;
std::vector<DataArrowPtr> output_data_arrows_;
// The op controls.
std::unordered_map<uuids::uuid *, std::vector<AID *>> input_op_controls_;
std::unordered_map<int, std::vector<AID *>> input_op_controls_;
std::vector<AID> output_control_arrows_;
};
@ -128,8 +140,7 @@ int MindrtRun(const std::vector<OpDataPtr<T>> &input_data, std::vector<OpDataPtr
const void *kernel_call_back_before, const void *kernel_call_back_after) {
OpContext<T> context;
std::vector<Promise<int>> promises(output_data->size());
uuids::uuid uid;
context.sequential_num_ = &uid;
context.sequential_num_ = RandInt::Instance().Get();
context.results_ = &promises;
context.output_data_ = output_data;
context.kernel_call_back_before_ = kernel_call_back_before;

View File

@ -33,8 +33,8 @@ using MessageHandler = std::function<void(ActorBase *)>;
class MessageAsync : public MessageBase {
public:
explicit MessageAsync(MessageHandler &h) : MessageBase("Async", Type::KASYNC), handler(h) {}
~MessageAsync() override {}
explicit MessageAsync(MessageHandler &&h) : MessageBase("Async", Type::KASYNC), handler(h) {}
virtual ~MessageAsync() = default;
void Run(ActorBase *actor) override { (handler)(actor); }
private:
@ -52,7 +52,7 @@ struct AsyncHelper<void> {
template <typename F>
void operator()(const AID &aid, F &&f) {
std::function<void(ActorBase *)> handler = [=](ActorBase *) { f(); };
std::unique_ptr<MessageAsync> msg(new (std::nothrow) MessageAsync(handler));
std::unique_ptr<MessageBase> msg(new (std::nothrow) MessageAsync(std::move(handler)));
MINDRT_OOM_EXIT(msg);
(void)ActorMgr::GetActorMgrRef()->Send(aid, std::move(msg));
}
@ -68,7 +68,7 @@ struct AsyncHelper<Future<R>> {
MessageHandler handler = [=](ActorBase *) { promise->Associate(f()); };
std::unique_ptr<MessageAsync> msg(new (std::nothrow) MessageAsync(handler));
std::unique_ptr<MessageBase> msg(new (std::nothrow) MessageAsync(std::move(handler)));
MINDRT_OOM_EXIT(msg);
(void)ActorMgr::GetActorMgrRef()->Send(aid, std::move(msg));
return future;
@ -84,7 +84,7 @@ struct AsyncHelper {
Future<R> future = promise->GetFuture();
std::function<void(ActorBase *)> handler = [=](ActorBase *) { promise->SetValue(f()); };
std::unique_ptr<MessageAsync> msg(new (std::nothrow) MessageAsync(handler));
std::unique_ptr<MessageBase> msg(new (std::nothrow) MessageAsync(std::move(handler)));
MINDRT_OOM_EXIT(msg);
(void)ActorMgr::GetActorMgrRef()->Send(aid, std::move(msg));
return future;
@ -102,7 +102,7 @@ void Async(const AID &aid, void (T::*method)()) {
MINDRT_ASSERT(t != nullptr);
(t->*method)();
};
std::unique_ptr<MessageAsync> msg(new (std::nothrow) MessageAsync(handler));
std::unique_ptr<MessageBase> msg(new (std::nothrow) MessageAsync(std::move(handler)));
MINDRT_OOM_EXIT(msg);
(void)ActorMgr::GetActorMgrRef()->Send(aid, std::move(msg));
}
@ -115,7 +115,7 @@ void Async(const AID &aid, void (T::*method)(Arg0), Arg1 &&arg) {
MINDRT_ASSERT(t != nullptr);
(t->*method)(arg);
};
std::unique_ptr<MessageAsync> msg(new (std::nothrow) MessageAsync(handler));
std::unique_ptr<MessageBase> msg(new (std::nothrow) MessageAsync(std::move(handler)));
MINDRT_OOM_EXIT(msg);
(void)ActorMgr::GetActorMgrRef()->Send(aid, std::move(msg));
}
@ -128,7 +128,7 @@ void Async(const AID &aid, void (T::*method)(Args0...), std::tuple<Args1...> &&t
MINDRT_ASSERT(t != nullptr);
Apply(t, method, tuple);
};
std::unique_ptr<MessageAsync> msg(new (std::nothrow) MessageAsync(handler));
std::unique_ptr<MessageBase> msg(new (std::nothrow) MessageAsync(std::move(handler)));
MINDRT_OOM_EXIT(msg);
(void)ActorMgr::GetActorMgrRef()->Send(aid, std::move(msg));
}
@ -152,7 +152,7 @@ Future<R> Async(const AID &aid, Future<R> (T::*method)()) {
MINDRT_ASSERT(t != nullptr);
promise->Associate((t->*method)());
};
std::unique_ptr<MessageAsync> msg(new (std::nothrow) MessageAsync(handler));
std::unique_ptr<MessageBase> msg(new (std::nothrow) MessageAsync(std::move(handler)));
MINDRT_OOM_EXIT(msg);
(void)ActorMgr::GetActorMgrRef()->Send(aid, std::move(msg));
return future;
@ -171,7 +171,7 @@ Future<R> Async(const AID &aid, Future<R> (T::*method)(Arg0), Arg1 &&arg) {
promise->Associate((t->*method)(arg));
};
std::unique_ptr<MessageAsync> msg(new (std::nothrow) MessageAsync(handler));
std::unique_ptr<MessageBase> msg(new (std::nothrow) MessageAsync(std::move(handler)));
MINDRT_OOM_EXIT(msg);
(void)ActorMgr::GetActorMgrRef()->Send(aid, std::move(msg));
return future;
@ -190,7 +190,7 @@ Future<R> Async(const AID &aid, Future<R> (T::*method)(Args0...), std::tuple<Arg
promise->Associate(Apply(t, method, tuple));
};
std::unique_ptr<MessageAsync> msg(new (std::nothrow) MessageAsync(handler));
std::unique_ptr<MessageBase> msg(new (std::nothrow) MessageAsync(std::move(handler)));
MINDRT_OOM_EXIT(msg);
(void)ActorMgr::GetActorMgrRef()->Send(aid, std::move(msg));
return future;
@ -217,7 +217,7 @@ Future<R> Async(const AID &aid, R (T::*method)()) {
promise->SetValue((t->*method)());
};
std::unique_ptr<MessageAsync> msg(new (std::nothrow) MessageAsync(handler));
std::unique_ptr<MessageBase> msg(new (std::nothrow) MessageAsync(std::move(handler)));
MINDRT_OOM_EXIT(msg);
(void)ActorMgr::GetActorMgrRef()->Send(aid, std::move(msg));
return future;
@ -237,7 +237,7 @@ Future<R> Async(const AID &aid, R (T::*method)(Arg0), Arg1 &&arg) {
MINDRT_ASSERT(t != nullptr);
promise->SetValue((t->*method)(arg));
};
std::unique_ptr<MessageAsync> msg(new (std::nothrow) MessageAsync(handler));
std::unique_ptr<MessageBase> msg(new (std::nothrow) MessageAsync(std::move(handler)));
MINDRT_OOM_EXIT(msg);
(void)ActorMgr::GetActorMgrRef()->Send(aid, std::move(msg));
return future;
@ -257,7 +257,7 @@ Future<R> Async(const AID &aid, R (T::*method)(Args0...), std::tuple<Args1...> &
MINDRT_ASSERT(t != nullptr);
promise->SetValue(Apply(t, method, tuple));
};
std::unique_ptr<MessageAsync> msg(new (std::nothrow) MessageAsync(handler));
std::unique_ptr<MessageBase> msg(new (std::nothrow) MessageAsync(std::move(handler)));
MINDRT_OOM_EXIT(msg);
(void)ActorMgr::GetActorMgrRef()->Send(aid, std::move(msg));
return future;

View File

@ -33,13 +33,13 @@ int Initialize(const std::string &tcpUrl, const std::string &tcpUrlAdv = "", con
const std::string &udpUrlAdv = "", int threadCount = 0);
// brief spawn a process to run an actor
AID Spawn(ActorReference actor, bool sharedThread = true, bool start = true);
AID Spawn(ActorReference actor, bool sharedThread = true);
// brief wait for the actor process to exit . It will be discarded
void Await(const ActorReference &actor);
// brief get actor with aid
ActorBase *GetActor(const AID &actor);
ActorReference GetActor(const AID &actor);
// brief wait for the actor process to exit
void Await(const AID &actor);
@ -47,9 +47,6 @@ void Await(const AID &actor);
// brief Terminate the actor to exit
void Terminate(const AID &actor);
// brief set the actor's running status
void SetActorStatus(const AID &actor, bool start);
// brief Terminate all actors
void TerminateAll();

View File

@ -16,28 +16,24 @@
#include "actor/actor.h"
#include "actor/actormgr.h"
#include "actor/actorpolicyinterface.h"
#include "actor/iomgr.h"
namespace mindspore {
ActorBase::ActorBase() : actorPolicy(nullptr), id("", ActorMgr::GetActorMgrRef()->GetUrl()), actionFunctions() {}
ActorBase::ActorBase() : mailbox(nullptr), id("", ActorMgr::GetActorMgrRef()->GetUrl()), actionFunctions() {}
ActorBase::ActorBase(const std::string &name)
: actorPolicy(nullptr), id(name, ActorMgr::GetActorMgrRef()->GetUrl()), actionFunctions() {}
: mailbox(nullptr), id(name, ActorMgr::GetActorMgrRef()->GetUrl()), actionFunctions() {}
ActorBase::ActorBase(const std::string &name, ActorThreadPool *pool)
: actorPolicy(nullptr), id(name, ActorMgr::GetActorMgrRef()->GetUrl()), actionFunctions(), pool_(pool) {}
: mailbox(nullptr), id(name, ActorMgr::GetActorMgrRef()->GetUrl()), actionFunctions(), pool_(pool) {}
ActorBase::~ActorBase() {}
void ActorBase::Spawn(const std::shared_ptr<ActorBase> &actor, std::unique_ptr<ActorPolicy> thread) {
// lock here or await(). and unlock at Quit() or at aweit.
void ActorBase::Spawn(const std::shared_ptr<ActorBase> &actor, std::unique_ptr<MailBox> mailboxPtr) {
// lock here or await(). and unlock at Quit() or at await.
waiterLock.lock();
actorPolicy = std::move(thread);
MINDRT_OOM_EXIT(actorPolicy);
this->mailbox = std::move(mailboxPtr);
}
void ActorBase::SetRunningStatus(bool start) { actorPolicy->SetRunningStatus(start); }
void ActorBase::Await() {
std::string actorName = id.Name();
@ -46,12 +42,19 @@ void ActorBase::Await() {
waiterLock.lock();
waiterLock.unlock();
// The mailbox's hook may hold the actor reference, so we need to explicitly free the mailbox to avoid a memory leak.
// For details, refer to the comments in ActorMgr::Spawn.
delete mailbox.release();
MS_LOG(DEBUG) << "ACTOR succeeded in waiting. a=" << actorName.c_str();
}
void ActorBase::Terminate() {
std::unique_ptr<MessageBase> msg(new (std::nothrow) MessageBase("Terminate", MessageBase::Type::KTERMINATE));
MINDRT_OOM_EXIT(msg);
(void)EnqueMessage(std::move(msg));
bool flag = false;
if (terminating_.compare_exchange_strong(flag, true)) {
std::unique_ptr<MessageBase> msg(new (std::nothrow) MessageBase("Terminate", MessageBase::Type::KTERMINATE));
MINDRT_OOM_EXIT(msg);
(void)EnqueMessage(std::move(msg));
}
}
void ActorBase::HandlekMsg(const std::unique_ptr<MessageBase> &msg) {
@ -64,61 +67,76 @@ void ActorBase::HandlekMsg(const std::unique_ptr<MessageBase> &msg) {
<< ",m=" << msg->Name().c_str();
}
}
int ActorBase::EnqueMessage(std::unique_ptr<MessageBase> &&msg) { return actorPolicy->EnqueMessage(std::move(msg)); }
int ActorBase::EnqueMessage(std::unique_ptr<MessageBase> msg) {
int ret = mailbox->EnqueueMessage(std::move(msg));
return ret;
}
void ActorBase::Quit() {
Finalize();
// lock at spawn(), unlock here.
waiterLock.unlock();
actorPolicy->Terminate(this);
}
void ActorBase::Run() {
for (;;) {
auto msgs = actorPolicy->GetMsgs();
if (msgs == nullptr) {
return;
}
for (auto it = msgs->begin(); it != msgs->end(); ++it) {
std::unique_ptr<MessageBase> &msg = *it;
if (msg == nullptr) {
continue;
auto msgHandler = [this](const std::unique_ptr<MessageBase> &msg) {
AddMsgRecord(msg->Name());
switch (msg->GetType()) {
case MessageBase::Type::KMSG:
case MessageBase::Type::KUDP: {
if (Filter(msg)) {
return ERRORCODE_SUCCESS;
}
this->HandlekMsg(msg);
return ERRORCODE_SUCCESS;
}
AddMsgRecord(msg->Name());
switch (msg->GetType()) {
case MessageBase::Type::KMSG:
case MessageBase::Type::KUDP: {
if (Filter(msg)) {
continue;
}
this->HandlekMsg(msg);
break;
case MessageBase::Type::KHTTP: {
this->HandleHttp(msg);
return ERRORCODE_SUCCESS;
}
case MessageBase::Type::KASYNC: {
msg->Run(this);
return ERRORCODE_SUCCESS;
}
case MessageBase::Type::KLOCAL: {
this->HandleLocalMsg(msg);
return ERRORCODE_SUCCESS;
}
case MessageBase::Type::KTERMINATE: {
this->Quit();
return ACTOR_TERMINATED;
}
case MessageBase::Type::KEXIT: {
this->Exited(msg->From());
return ERRORCODE_SUCCESS;
}
}
return ERRORCODE_SUCCESS;
};
if (this->mailbox->TakeAllMsgsEachTime()) {
while (auto msgs = mailbox->GetMsgs()) {
for (auto it = msgs->begin(); it != msgs->end(); ++it) {
std::unique_ptr<MessageBase> &msg = *it;
if (msg == nullptr) {
continue;
}
case MessageBase::Type::KHTTP: {
this->HandleHttp(std::move(msg));
break;
}
case MessageBase::Type::KASYNC: {
msg->Run(this);
break;
}
case MessageBase::Type::KLOCAL: {
this->HandleLocalMsg(std::move(msg));
break;
}
case MessageBase::Type::KTERMINATE: {
this->Quit();
MS_LOG_DEBUG << "dequeue message]actor=" << id.Name() << ",msg=" << msg->Name();
if (msgHandler(msg) == ACTOR_TERMINATED) {
return;
}
case MessageBase::Type::KEXIT: {
this->Exited(msg->From());
break;
}
}
msgs->clear();
}
} else {
while (auto msg = mailbox->GetMsg()) {
if (msgHandler(msg) == ACTOR_TERMINATED) {
return;
}
}
msgs->clear();
}
return;
}
int ActorBase::Send(const AID &to, std::unique_ptr<MessageBase> msg) {

View File

@ -20,7 +20,6 @@
#include <utility>
#include "actor/actormgr.h"
#include "actor/actorpolicy.h"
#include "actor/iomgr.h"
namespace mindspore {
@ -134,10 +133,7 @@ void ActorMgr::TerminateAll() {
// send terminal msg to all actors.
for (auto actorIt = actorsWaiting.begin(); actorIt != actorsWaiting.end(); ++actorIt) {
(*actorIt)->SetRunningStatus(true);
std::unique_ptr<MessageBase> msg(new (std::nothrow) MessageBase("Terminate", MessageBase::Type::KTERMINATE));
MINDRT_OOM_EXIT(msg);
(void)(*actorIt)->EnqueMessage(std::move(msg));
(*actorIt)->Terminate();
}
// wait actor's thread to finish.
@ -165,7 +161,7 @@ void ActorMgr::Finalize() {
MS_LOG(INFO) << "mindrt IOMGRS finish exiting.";
}
ActorBase *ActorMgr::GetActor(const AID &id) {
ActorReference ActorMgr::GetActor(const AID &id) {
#ifndef MS_COMPILE_IOS
actorsMutex.lock_shared();
#else
@ -179,7 +175,7 @@ ActorBase *ActorMgr::GetActor(const AID &id) {
#else
actorsMutex.unlock();
#endif
return result.get();
return result;
} else {
#ifndef MS_COMPILE_IOS
actorsMutex.unlock_shared();
@ -191,7 +187,11 @@ ActorBase *ActorMgr::GetActor(const AID &id) {
}
}
int ActorMgr::Send(const AID &to, std::unique_ptr<MessageBase> &&msg, bool remoteLink, bool isExactNotRemote) {
// Forwards msg to the given actor's own EnqueMessage() and passes its result back to the caller.
int ActorMgr::EnqueueMessage(mindspore::ActorReference actor, std::unique_ptr<mindspore::MessageBase> msg) {
  const int status = actor->EnqueMessage(std::move(msg));
  return status;
}
int ActorMgr::Send(const AID &to, std::unique_ptr<MessageBase> msg, bool remoteLink, bool isExactNotRemote) {
// The destination is local
if (IsLocalAddres(to)) {
auto actor = GetActor(to);
@ -199,7 +199,7 @@ int ActorMgr::Send(const AID &to, std::unique_ptr<MessageBase> &&msg, bool remot
if (to.GetProtocol() == MINDRT_UDP && msg->GetType() == MessageBase::Type::KMSG) {
msg->type = MessageBase::Type::KUDP;
}
return actor->EnqueMessage(std::move(msg));
return EnqueueMessage(actor, std::move(msg));
} else {
return ACTOR_NOT_FIND;
}
@ -224,7 +224,7 @@ int ActorMgr::Send(const AID &to, std::unique_ptr<MessageBase> &&msg, bool remot
}
}
AID ActorMgr::Spawn(const ActorReference &actor, bool shareThread, bool start) {
AID ActorMgr::Spawn(const ActorReference &actor, bool shareThread) {
actorsMutex.lock();
if (actors.find(actor->GetAID().Name()) != actors.end()) {
actorsMutex.unlock();
@ -232,17 +232,20 @@ AID ActorMgr::Spawn(const ActorReference &actor, bool shareThread, bool start) {
MINDRT_EXIT("Actor name conflicts.");
}
MS_LOG(DEBUG) << "ACTOR was spawned,a=" << actor->GetAID().Name().c_str();
std::unique_ptr<ActorPolicy> threadPolicy;
if (shareThread) {
threadPolicy.reset(new (std::nothrow) ShardedThread(actor));
MINDRT_OOM_EXIT(threadPolicy);
actor->Spawn(actor, std::move(threadPolicy));
auto mailbox = std::unique_ptr<MailBox>(new (std::nothrow) NonblockingMailBox());
auto hook = std::unique_ptr<std::function<void()>>(
new std::function<void()>([actor]() { ActorMgr::GetActorMgrRef()->SetActorReady(actor); }));
// The mailbox owns this hook, the hook holds the actor reference, and the actor owns the mailbox. This is a cycle
// which would lead to a memory leak. To break it, we must explicitly free the mailbox when the actor is
// terminated.
mailbox->SetNotifyHook(std::move(hook));
actor->Spawn(actor, std::move(mailbox));
} else {
threadPolicy.reset(new (std::nothrow) SingleThread());
MINDRT_OOM_EXIT(threadPolicy);
actor->Spawn(actor, std::move(threadPolicy));
auto mailbox = std::unique_ptr<MailBox>(new (std::nothrow) BlockingMailBox());
actor->Spawn(actor, std::move(mailbox));
ActorMgr::GetActorMgrRef()->SetActorReady(actor);
}
@ -252,27 +255,16 @@ AID ActorMgr::Spawn(const ActorReference &actor, bool shareThread, bool start) {
// long time
actor->Init();
actor->SetRunningStatus(start);
return actor->GetAID();
}
void ActorMgr::Terminate(const AID &id) {
auto actor = GetActor(id);
if (actor != nullptr) {
std::unique_ptr<MessageBase> msg(new (std::nothrow) MessageBase("Terminate", MessageBase::Type::KTERMINATE));
MINDRT_OOM_EXIT(msg);
(void)actor->EnqueMessage(std::move(msg));
actor->Terminate();
// Wait actor's thread to finish.
actor->Await();
}
}
void ActorMgr::SetActorStatus(const AID &pid, bool start) {
auto actor = GetActor(pid);
if (actor != nullptr) {
actor->SetRunningStatus(start);
RemoveActor(id.Name());
}
}

View File

@ -43,7 +43,7 @@ class ActorMgr {
static inline std::shared_ptr<IOMgr> &GetIOMgrRef(const AID &to) { return GetIOMgrRef(to.GetProtocol()); }
static void Receive(std::unique_ptr<MessageBase> &&msg) {
static void Receive(std::unique_ptr<MessageBase> msg) {
auto to = msg->To().Name();
(void)ActorMgr::GetActorMgrRef()->Send(AID(to), std::move(msg));
}
@ -58,21 +58,19 @@ class ActorMgr {
int Initialize(bool use_inner_pool = false, size_t actor_thread_num = 1, size_t max_thread_num = 1);
void RemoveActor(const std::string &name);
ActorBase *GetActor(const AID &id);
ActorReference GetActor(const AID &id);
const std::string GetUrl(const std::string &protocol = "tcp");
void AddUrl(const std::string &protocol, const std::string &url);
void AddIOMgr(const std::string &protocol, const std::shared_ptr<IOMgr> &ioMgr);
int Send(const AID &to, std::unique_ptr<MessageBase> &&msg, bool remoteLink = false, bool isExactNotRemote = false);
AID Spawn(const ActorReference &actor, bool shareThread = true, bool start = true);
int Send(const AID &to, std::unique_ptr<MessageBase> msg, bool remoteLink = false, bool isExactNotRemote = false);
AID Spawn(const ActorReference &actor, bool shareThread = true);
void Terminate(const AID &id);
void TerminateAll();
void Wait(const AID &pid);
inline const std::string &GetDelegate() const { return delegate; }
inline void SetDelegate(const std::string &d) { delegate = d; }
void SetActorReady(const ActorReference &actor) const;
void SetActorStatus(const AID &pid, bool start);
private:
inline bool IsLocalAddres(const AID &id) {
@ -82,6 +80,7 @@ class ActorMgr {
return false;
}
}
int EnqueueMessage(ActorReference actor, std::unique_ptr<MessageBase> msg);
// in order to avoid being initialized many times
std::atomic_bool initialized_{false};

View File

@ -1,121 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "actor/actor.h"
#include "actor/actormgr.h"
#include "actor/actorpolicy.h"
namespace mindspore {
void ActorPolicy::SetRunningStatus(bool startRun) {
std::lock_guard<std::mutex> lock(mailboxLock);
this->start = startRun;
Notify();
}
// SingleThread performs no work of its own on construction or destruction.
SingleThread::SingleThread() = default;
SingleThread::~SingleThread() = default;
// Handles termination of the given actor: logs the event and removes the actor, by name, from the ActorMgr.
void SingleThread::Terminate(const ActorBase *actor) {
  // NOTE(review): presumably aborts when actor is null; confirm against the macro definition.
  MINDRT_OOM_EXIT(actor);
  const std::string name = actor->GetAID().Name();
  MS_LOG(DEBUG) << "ACTOR SingleThread received terminate message, v=" << name.c_str();
  // Unregister the actor so the manager no longer tracks it.
  ActorMgr::GetActorMgrRef()->RemoveActor(name);
}
// Appends a message to the enqueue mailbox and wakes the consumer when the mailbox goes from empty to one message.
//
// msg: the message to queue; it is moved into the mailbox.
// Returns the number of messages pending after this one was added.
int SingleThread::EnqueMessage(std::unique_ptr<MessageBase> &&msg) {
  int pending = 0;
  bool wakeConsumer = false;
  {
    std::lock_guard<std::mutex> lock(mailboxLock);
    enqueMailbox->push_back(std::move(msg));
    pending = ++msgCount;
    // SetRunningStatus() writes `start` while holding mailboxLock, so the flag has to be sampled under the same
    // lock; reading it after the lock is released would be an unsynchronized access.
    wakeConsumer = start && (pending == 1);
  }
  // Signal outside the critical section so the woken thread does not immediately block on mailboxLock.
  if (wakeConsumer) {
    conditionVar.notify_one();
  }
  return pending;
}
// Wakes one waiter on the condition variable, but only when the policy is started and messages are pending.
void SingleThread::Notify() {
  const bool hasRunnableWork = start && msgCount > 0;
  if (!hasRunnableWork) {
    return;
  }
  conditionVar.notify_one();
}
std::list<std::unique_ptr<MessageBase>> *SingleThread::GetMsgs() {
std::list<std::unique_ptr<MessageBase>> *result;
std::unique_lock<std::mutex> lock(mailboxLock);
conditionVar.wait(lock, [this] { return (!this->enqueMailbox->empty()); });
SwapMailbox();
// REF_PRIVATE_MEMBER
result = dequeMailbox;
return result;
}
ShardedThread::ShardedThread(const std::shared_ptr<ActorBase> &aActor)
: ready(false), terminated(false), actor(aActor) {}
ShardedThread::~ShardedThread() {}
void ShardedThread::Terminate(const ActorBase *aActor) {
std::string actorName = aActor->GetAID().Name();
mailboxLock.lock();
terminated = true;
this->actor = nullptr;
mailboxLock.unlock();
// remove actor from actorMgr
ActorMgr::GetActorMgrRef()->RemoveActor(actorName);
}
// Appends a message to the enqueue mailbox and, if the actor is started but not yet marked ready (and not
// terminated), asks the ActorMgr to mark it ready.
//
// msg: the message to queue; it is moved into the mailbox.
// Returns the number of messages pending right after this one was added.
int ShardedThread::EnqueMessage(std::unique_ptr<MessageBase> &&msg) {
  // RAII guard: the lock is released on every exit path.
  std::lock_guard<std::mutex> guard(mailboxLock);
  enqueMailbox->push_back(std::move(msg));
  // Take the count while the lock is held. Returning msgCount itself after unlocking would read a value that
  // another thread may be modifying concurrently (SwapMailbox() resets it, other producers increment it).
  const int pending = ++msgCount;
  // true : The actor is running. else the actor will be ready to run.
  if (start && !ready && !terminated) {
    ActorMgr::GetActorMgrRef()->SetActorReady(actor);
    ready = true;
  }
  return pending;
}
// Marks the actor ready through the ActorMgr when it is started, not already ready, not terminated and has
// pending messages; otherwise does nothing.
void ShardedThread::Notify() {
  if (!start || ready || terminated || msgCount <= 0) {
    return;
  }
  ActorMgr::GetActorMgrRef()->SetActorReady(actor);
  ready = true;
}
// Non-blocking fetch of the pending messages.
// Returns nullptr (and clears the ready flag) when nothing is queued; otherwise swaps the mailboxes, keeps the
// ready flag set and returns the list holding the messages.
std::list<std::unique_ptr<MessageBase>> *ShardedThread::GetMsgs() {
  std::lock_guard<std::mutex> guard(mailboxLock);
  if (enqueMailbox->empty()) {
    ready = false;
    return nullptr;
  }
  ready = true;
  SwapMailbox();
  return dequeMailbox;
}
}; // end of namespace mindspore

View File

@ -1,61 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_MINDRT_SRC_ACTOR_ACTORPOLICY_H
#define MINDSPORE_CORE_MINDRT_SRC_ACTOR_ACTORPOLICY_H
#include <list>
#include <memory>
#include <string>
#include <utility>
#include "actor/actorpolicyinterface.h"
namespace mindspore {
// ActorPolicy implementation that keeps a shared reference to its actor and tracks whether that actor has been
// marked ready or has been terminated.
class ShardedThread : public ActorPolicy {
public:
// Binds the policy to the actor it serves.
explicit ShardedThread(const std::shared_ptr<ActorBase> &actor);
virtual ~ShardedThread();
protected:
// Implementations of the ActorPolicy hooks declared as pure virtual in the base class.
virtual void Terminate(const ActorBase *actor);
virtual int EnqueMessage(std::unique_ptr<MessageBase> &&msg);
virtual std::list<std::unique_ptr<MessageBase>> *GetMsgs();
virtual void Notify();
private:
// Set once the actor has been handed to ActorMgr::SetActorReady; cleared when GetMsgs() finds no messages.
bool ready;
// Set by Terminate(); a terminated policy no longer marks its actor ready.
bool terminated;
// The actor served by this policy; reset to nullptr by Terminate().
std::shared_ptr<ActorBase> actor;
};
// ActorPolicy implementation whose GetMsgs() blocks on a condition variable until a message has been enqueued.
class SingleThread : public ActorPolicy {
public:
SingleThread();
virtual ~SingleThread();
protected:
// Implementations of the ActorPolicy hooks declared as pure virtual in the base class.
virtual void Terminate(const ActorBase *actor);
virtual int EnqueMessage(std::unique_ptr<MessageBase> &&msg);
virtual std::list<std::unique_ptr<MessageBase>> *GetMsgs();
virtual void Notify();
private:
// Signalled by EnqueMessage()/Notify(); waited on by GetMsgs() together with the base class's mailboxLock.
std::condition_variable conditionVar;
};
}; // end of namespace mindspore
#endif

View File

@ -1,64 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_MINDRT_SRC_ACTOR_ACTORPOLICYINTERFACE_H
#define MINDSPORE_CORE_MINDRT_SRC_ACTOR_ACTORPOLICYINTERFACE_H
#include <list>
#include <memory>
namespace mindspore {
// Base class for an actor's message-delivery policy.
// It owns two message lists: producers append to the one behind enqueMailbox, and SwapMailbox() exchanges the
// roles so the filled list can be handed out through dequeMailbox.
class ActorPolicy {
 public:
  // Starts with mailbox1 as the enqueue side, mailbox2 as the dequeue side, no pending messages and not started.
  ActorPolicy() : enqueMailbox(&mailbox1), dequeMailbox(&mailbox2), msgCount(0), start(false) {}
  virtual ~ActorPolicy() = default;

  // Exchanges the enqueue and dequeue lists and resets the pending-message counter.
  inline void SwapMailbox() {
    std::list<std::unique_ptr<MessageBase>> *const previousDeque = dequeMailbox;
    dequeMailbox = enqueMailbox;
    enqueMailbox = previousDeque;
    msgCount = 0;
  }

 protected:
  // Sets the running flag; defined out of line.
  void SetRunningStatus(bool startRun);
  // Hooks every concrete policy has to provide.
  virtual void Terminate(const ActorBase *actor) = 0;
  virtual int EnqueMessage(std::unique_ptr<MessageBase> &&msg) = 0;
  virtual std::list<std::unique_ptr<MessageBase>> *GetMsgs() = 0;
  virtual void Notify() = 0;

  // List currently receiving new messages.
  std::list<std::unique_ptr<MessageBase>> *enqueMailbox;
  // List most recently handed out for processing.
  std::list<std::unique_ptr<MessageBase>> *dequeMailbox;
  // Messages added since the last SwapMailbox().
  int msgCount;
  // Running flag set through SetRunningStatus().
  bool start;
  // Mutex used by this class and its subclasses around mailbox and flag updates.
  std::mutex mailboxLock;

 private:
  friend class ActorBase;
  // Backing storage for the two lists the pointers above alternate between.
  std::list<std::unique_ptr<MessageBase>> mailbox1;
  std::list<std::unique_ptr<MessageBase>> mailbox2;
};
}; // end of namespace mindspore
#endif

View File

@ -0,0 +1,90 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "actor/mailbox.h"
namespace mindspore {
// Appends msg to the enqueue list under the mailbox lock, then wakes every thread waiting on the condition
// variable. Always returns 0.
int BlockingMailBox::EnqueueMessage(std::unique_ptr<mindspore::MessageBase> msg) {
  {
    std::lock_guard<std::mutex> guard(lock);
    enqueMailBox->emplace_back(std::move(msg));
  }
  // Notify after the lock has been released.
  cond.notify_all();
  return 0;
}
// Blocks until at least one message has been enqueued, then swaps the enqueue and dequeue lists and returns the
// list that now holds the pending messages.
std::list<std::unique_ptr<MessageBase>> *BlockingMailBox::GetMsgs() {
  std::unique_lock<std::mutex> guard(lock);
  // The predicate form returns at once when messages are already queued and re-checks after each wake-up.
  cond.wait(guard, [this] { return !enqueMailBox->empty(); });
  SwapMailBox(&enqueMailBox, &dequeMailBox);
  return dequeMailBox;
}
// Appends msg to the enqueue list. The notify hook (if one is installed) is invoked only when the list was empty
// before this message and the released_ flag was set at that moment. Always returns 0.
int NonblockingMailBox::EnqueueMessage(std::unique_ptr<mindspore::MessageBase> msg) {
  bool wasEmpty = false;
  bool wasReleased = false;
  {
    std::lock_guard<std::mutex> guard(lock);
    wasEmpty = enqueMailBox->empty();
    wasReleased = released_;
    enqueMailBox->emplace_back(std::move(msg));
  }
  // The hook runs outside the lock.
  if (wasEmpty && wasReleased && notifyHook) {
    (*notifyHook)();
  }
  return 0;
}
// Non-blocking fetch of the pending messages.
// With nothing queued it sets released_ and returns nullptr; otherwise it swaps the two lists, clears released_
// and returns the list holding the messages.
std::list<std::unique_ptr<MessageBase>> *NonblockingMailBox::GetMsgs() {
  std::lock_guard<std::mutex> guard(lock);
  const bool hasPending = !enqueMailBox->empty();
  // released_ records that the last fetch came back empty; EnqueueMessage() reads it to decide whether to notify.
  released_ = !hasPending;
  if (!hasPending) {
    return nullptr;
  }
  SwapMailBox(&enqueMailBox, &dequeMailBox);
  return dequeMailBox;
}
// Releases ownership of msg and pushes the raw pointer into the queue, retrying until the queue accepts it.
// The notify hook (if one is installed) is invoked when the queue reported empty before the push.
// Always returns 0.
int HQueMailBox::EnqueueMessage(std::unique_ptr<mindspore::MessageBase> msg) {
  // NOTE(review): emptiness is sampled before the push and is not atomic with it - confirm this is intended.
  const bool wasEmpty = mailbox.Empty();
  MessageBase *rawMsg = msg.release();
  // Spin until the enqueue succeeds.
  for (;;) {
    if (mailbox.Enqueue(rawMsg)) {
      break;
    }
  }
  if (wasEmpty && notifyHook) {
    (*notifyHook)();
  }
  return 0;
}
// Takes one entry out of the queue and returns it wrapped in a unique_ptr, so the caller owns the message.
// NOTE(review): presumably Dequeue() yields nullptr when the queue is empty - confirm against HQueue::Dequeue.
std::unique_ptr<MessageBase> HQueMailBox::GetMsg() {
  return std::unique_ptr<MessageBase>(mailbox.Dequeue());
}
} // namespace mindspore

View File

@ -0,0 +1,105 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MAILBOX_H
#define MINDSPORE_MAILBOX_H
#include <list>
#include <memory>
#include <mutex>
#include <condition_variable>
#include <functional>
#include <utility>
#include "actor/msg.h"
#include "thread/hqueue.h"
namespace mindspore {
// Abstract interface of a message container. Concrete mailboxes either hand out all queued messages at once through
// GetMsgs() or one message at a time through GetMsg(); TakeAllMsgsEachTime() tells the caller which one applies.
class MailBox {
 public:
  virtual ~MailBox() = default;

  // Takes ownership of msg and stores it in the mailbox.
  virtual int EnqueueMessage(std::unique_ptr<MessageBase> msg) = 0;

  // Returns a pointer to a list of queued messages.
  virtual std::list<std::unique_ptr<MessageBase>> *GetMsgs() = 0;

  // Returns a single queued message.
  virtual std::unique_ptr<MessageBase> GetMsg() = 0;

  // Installs the callback that mailbox implementations may invoke from EnqueueMessage().
  void SetNotifyHook(std::unique_ptr<std::function<void()>> &&hook) { notifyHook = std::move(hook); }

  // Reports the retrieval mode of this mailbox (see takeAllMsgsEachTime).
  bool TakeAllMsgsEachTime() { return takeAllMsgsEachTime; }

  // Exchanges the two list pointers that box1 and box2 point to.
  void SwapMailBox(std::list<std::unique_ptr<MessageBase>> **box1, std::list<std::unique_ptr<MessageBase>> **box2) {
    std::swap(*box1, *box2);
  }

 protected:
  // if this flag is true, GetMsgs() should be invoked to take all enqueued msgs each time, otherwise we can only get
  // one msg by GetMsg() each time.
  bool takeAllMsgsEachTime = true;
  // Optional callback; it is empty until SetNotifyHook() is called.
  std::unique_ptr<std::function<void()>> notifyHook;
};
class BlockingMailBox : public MailBox {
public:
BlockingMailBox() : mailbox1(), mailbox2(), enqueMailBox(&mailbox1), dequeMailBox(&mailbox2) {}
virtual ~BlockingMailBox() {
mailbox1.clear();
mailbox2.clear();
}
int EnqueueMessage(std::unique_ptr<MessageBase> msg) override;
std::list<std::unique_ptr<MessageBase>> *GetMsgs() override;
std::unique_ptr<MessageBase> GetMsg() override { return nullptr; }
private:
std::list<std::unique_ptr<MessageBase>> mailbox1;
std::list<std::unique_ptr<MessageBase>> mailbox2;
std::list<std::unique_ptr<MessageBase>> *enqueMailBox;
std::list<std::unique_ptr<MessageBase>> *dequeMailBox;
std::mutex lock;
std::condition_variable cond;
};
class NonblockingMailBox : public MailBox {
public:
NonblockingMailBox() : mailbox1(), mailbox2(), enqueMailBox(&mailbox1), dequeMailBox(&mailbox2) {}
virtual ~NonblockingMailBox() {
mailbox1.clear();
mailbox2.clear();
}
int EnqueueMessage(std::unique_ptr<MessageBase> msg) override;
std::list<std::unique_ptr<MessageBase>> *GetMsgs() override;
std::unique_ptr<MessageBase> GetMsg() override { return nullptr; }
private:
std::list<std::unique_ptr<MessageBase>> mailbox1;
std::list<std::unique_ptr<MessageBase>> mailbox2;
std::list<std::unique_ptr<MessageBase>> *enqueMailBox;
std::list<std::unique_ptr<MessageBase>> *dequeMailBox;
std::mutex lock;
bool released_ = true;
};
// Mailbox backed by an HQueue of raw MessageBase pointers. It switches takeAllMsgsEachTime off, so messages are
// retrieved one at a time through GetMsg().
class HQueMailBox : public MailBox {
 public:
  HQueMailBox() { takeAllMsgsEachTime = false; }

  // Initializes the underlying queue with MAX_MSG_QUE_SIZE and forwards its result.
  // NOTE(review): presumably this has to succeed before EnqueueMessage()/GetMsg() are used - confirm against HQueue.
  bool Init() { return mailbox.Init(MAX_MSG_QUE_SIZE); }

  int EnqueueMessage(std::unique_ptr<MessageBase> msg) override;
  // Bulk retrieval is not provided by this mailbox; always returns nullptr.
  std::list<std::unique_ptr<MessageBase>> *GetMsgs() override { return nullptr; }
  std::unique_ptr<MessageBase> GetMsg() override;

 private:
  HQueue<MessageBase> mailbox;
  // Size argument handed to HQueue::Init().
  static constexpr int32_t MAX_MSG_QUE_SIZE = 4096;
};
} // namespace mindspore
#endif // MINDSPORE_MAILBOX_H

View File

@ -79,7 +79,7 @@ int Initialize(const std::string &tcpUrl, const std::string &tcpUrlAdv, const st
return result;
}
AID Spawn(const ActorReference actor, bool sharedThread, bool start) {
AID Spawn(ActorReference actor, bool sharedThread) {
if (actor == nullptr) {
MS_LOG(ERROR) << "Actor is nullptr.";
MINDRT_EXIT("Actor is nullptr.");
@ -88,17 +88,16 @@ AID Spawn(const ActorReference actor, bool sharedThread, bool start) {
if (local::g_finalizeMindrtStatus.load() == true) {
return actor->GetAID();
} else {
return ActorMgr::GetActorMgrRef()->Spawn(actor, sharedThread, start);
return ActorMgr::GetActorMgrRef()->Spawn(actor, sharedThread);
}
}
void SetActorStatus(const AID &actor, bool start) { ActorMgr::GetActorMgrRef()->SetActorStatus(actor, start); }
void Await(const ActorReference &actor) { ActorMgr::GetActorMgrRef()->Wait(actor->GetAID()); }
void Await(const AID &actor) { ActorMgr::GetActorMgrRef()->Wait(actor); }
// brief get actor with aid
ActorBase *GetActor(const AID &actor) { return ActorMgr::GetActorMgrRef()->GetActor(actor); }
ActorReference GetActor(const AID &actor) { return ActorMgr::GetActorMgrRef()->GetActor(actor); }
void Terminate(const AID &actor) { ActorMgr::GetActorMgrRef()->Terminate(actor); }

View File

@ -140,7 +140,7 @@ void ActorThreadPool::PushActorToQueue(ActorBase *actor) {
int ActorThreadPool::CreateThreads(size_t actor_thread_num, size_t all_thread_num, const std::vector<int> &core_list) {
#ifdef USE_HQUEUE
if (actor_queue_.Init(MAX_READY_ACTOR_NR) != THREAD_OK) {
if (actor_queue_.Init(MAX_READY_ACTOR_NR) != true) {
THREAD_ERROR("init actor queue failed.");
return THREAD_ERROR;
}

View File

@ -18,8 +18,6 @@
#define MINDSPORE_CORE_MINDRT_RUNTIME_HQUEUE_H_
#include <atomic>
#include <vector>
#include "actor/log.h"
#include "thread/threadlog.h"
namespace mindspore {
// implement a lock-free queue
@ -50,10 +48,13 @@ class HQueue {
HQueue() {}
virtual ~HQueue() {}
int Init(int32_t sz) {
bool Init(int32_t sz) {
for (int32_t i = 0; i < sz; i++) {
auto node = new HQNode<T>();
THREAD_ERROR_IF_NULL(node);
if (node == nullptr) {
Clean();
return false;
}
node->value = nullptr;
node->free = true;
node->next = {-1, 0};
@ -64,7 +65,7 @@ class HQueue {
qhead = {0, 0};
qtail = {0, 0};
nodes[0]->free = false;
return THREAD_OK;
return true;
}
void Clean() {
@ -163,6 +164,7 @@ class HQueue {
return false;
}
private:
std::atomic<Pointer> qhead;
std::atomic<Pointer> qtail;
std::vector<HQNode<T> *> nodes;