forked from mindspore-Ecosystem/mindspore
!20841 [MS][LITE]implement shared thread policy with hqueue in mindrt
Merge pull request !20841 from zhaizhiqiang/master
This commit is contained in:
commit
a73befd070
|
@ -60,7 +60,7 @@ class SwitchActor : public AbstractActor {
|
|||
std::vector<KernelWithIndex> formal_parameters_;
|
||||
|
||||
// Input data.
|
||||
std::unordered_map<uuids::uuid *, std::unordered_map<size_t, std::vector<KernelWithIndex>>> input_nodes_;
|
||||
std::unordered_map<int, std::unordered_map<size_t, std::vector<KernelWithIndex>>> input_nodes_;
|
||||
// The store node records the value node input of the switch actor.
|
||||
std::vector<std::pair<size_t, AnfNodePtr>> store_nodes_;
|
||||
|
||||
|
|
|
@ -173,8 +173,7 @@ void MemoryManagerActor::SetOpContextMemoryAllocFail(const std::string &kernel_n
|
|||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
MS_EXCEPTION_IF_NULL(op_context);
|
||||
|
||||
MS_EXCEPTION_IF_NULL(op_context->sequential_num_);
|
||||
auto step_id = uuids::uuid::ToBytes(*(op_context->sequential_num_));
|
||||
int step_id = op_context->sequential_num_;
|
||||
// First occur allocating memory failed.
|
||||
if (mem_alloc_failed_step_ids_.find(step_id) == mem_alloc_failed_step_ids_.end()) {
|
||||
mem_alloc_failed_step_ids_.clear();
|
||||
|
|
|
@ -70,7 +70,7 @@ class MemoryManagerActor : public ActorBase {
|
|||
// MemoryManagerActor object is used like a single instance, if one actor allocates memory failed in one batch, which
|
||||
// will set fail message info OpContext, major thread will destroy the OpContext object, subsequent actor can not set
|
||||
// fail message again, so we record allocating memory fail event by the uuid of the batch, which is key of the set.
|
||||
std::set<std::string> mem_alloc_failed_step_ids_;
|
||||
std::set<int> mem_alloc_failed_step_ids_;
|
||||
};
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -328,9 +328,8 @@ void GraphScheduler::Run(const ActorSet *actor_set, const std::vector<std::vecto
|
|||
|
||||
// Construct OpContext.
|
||||
OpContext<DeviceTensor> op_context;
|
||||
uuids::uuid sequential_num;
|
||||
std::vector<Promise<int>> result(1);
|
||||
op_context.sequential_num_ = &sequential_num;
|
||||
op_context.sequential_num_ = RandInt::Instance().Get();
|
||||
op_context.results_ = &result;
|
||||
|
||||
if ((strategy == GraphExecutionStrategy::kStep) && IsSingleOpActorSet(actor_set)) {
|
||||
|
|
|
@ -24,14 +24,13 @@
|
|||
#include <string>
|
||||
#include <utility>
|
||||
#include "thread/hqueue.h"
|
||||
|
||||
#include "actor/msg.h"
|
||||
#include "actor/mailbox.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
||||
class ActorBase;
|
||||
class ActorMgr;
|
||||
class ActorPolicy;
|
||||
class ActorWorker;
|
||||
class ActorThreadPool;
|
||||
|
||||
|
@ -40,7 +39,7 @@ using ActorReference = std::shared_ptr<ActorBase>;
|
|||
// should be at least greater than 1
|
||||
constexpr uint32_t MAX_ACTOR_RECORD_SIZE = 3;
|
||||
|
||||
class ActorBase : std::enable_shared_from_this<ActorBase> {
|
||||
class ActorBase {
|
||||
public:
|
||||
inline const AID &GetAID() const { return id; }
|
||||
|
||||
|
@ -58,6 +57,7 @@ class ActorBase : std::enable_shared_from_this<ActorBase> {
|
|||
startPoint = (startPoint + MAX_ACTOR_RECORD_SIZE - 1) % MAX_ACTOR_RECORD_SIZE;
|
||||
}
|
||||
}
|
||||
|
||||
ActorBase();
|
||||
explicit ActorBase(const std::string &name);
|
||||
explicit ActorBase(const std::string &name, ActorThreadPool *pool);
|
||||
|
@ -97,12 +97,12 @@ class ActorBase : std::enable_shared_from_this<ActorBase> {
|
|||
virtual void Finalize() {}
|
||||
|
||||
// KHTTPMsg handler
|
||||
virtual void HandleHttp(std::unique_ptr<MessageBase> msg) {
|
||||
virtual void HandleHttp(const std::unique_ptr<MessageBase> &msg) {
|
||||
MS_LOG(ERROR) << "ACTOR (" << id.Name().c_str() << ") HandleHttp() is not implemented";
|
||||
}
|
||||
|
||||
// KLOCALMsg handler
|
||||
virtual void HandleLocalMsg(std::unique_ptr<MessageBase> msg) {
|
||||
virtual void HandleLocalMsg(const std::unique_ptr<MessageBase> &msg) {
|
||||
MS_LOG(ERROR) << "ACTOR (" << id.Name().c_str() << ") HandleLocalMsg() is not implemented.";
|
||||
}
|
||||
|
||||
|
@ -200,12 +200,12 @@ class ActorBase : std::enable_shared_from_this<ActorBase> {
|
|||
|
||||
void Run();
|
||||
void Quit();
|
||||
int EnqueMessage(std::unique_ptr<MessageBase> &&msg);
|
||||
int EnqueMessage(std::unique_ptr<MessageBase> msg);
|
||||
|
||||
void Spawn(const std::shared_ptr<ActorBase> &actor, std::unique_ptr<ActorPolicy> actorThread);
|
||||
void SetRunningStatus(bool start);
|
||||
void Spawn(const std::shared_ptr<ActorBase> &actor, std::unique_ptr<MailBox> mailbox);
|
||||
|
||||
std::unique_ptr<ActorPolicy> actorPolicy;
|
||||
std::unique_ptr<MailBox> mailbox;
|
||||
std::atomic_bool terminating_ = false;
|
||||
|
||||
AID id;
|
||||
std::map<std::string, ActorFunction> actionFunctions;
|
||||
|
|
|
@ -27,6 +27,7 @@ constexpr int ERRORCODE_SUCCESS = 1;
|
|||
constexpr int ACTOR_PARAMER_ERR = -101;
|
||||
constexpr int ACTOR_NOT_FIND = -102;
|
||||
constexpr int IO_NOT_FIND = -103;
|
||||
constexpr int ACTOR_TERMINATED = -104;
|
||||
|
||||
// TCP module err code -301 ~ -400
|
||||
// Null
|
||||
|
|
|
@ -48,6 +48,18 @@ struct OpData {
|
|||
int index_;
|
||||
};
|
||||
|
||||
class RandInt {
|
||||
public:
|
||||
int Get() { return rand(); }
|
||||
static RandInt &Instance() {
|
||||
static RandInt instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
private:
|
||||
RandInt() { srand(time(NULL)); }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using OpDataPtr = std::shared_ptr<OpData<T>>;
|
||||
|
||||
|
@ -57,7 +69,7 @@ using OpDataUniquePtr = std::unique_ptr<OpData<T>>;
|
|||
// The context of opActor running.
|
||||
template <typename T>
|
||||
struct OpContext {
|
||||
uuids::uuid *sequential_num_;
|
||||
int sequential_num_;
|
||||
std::vector<OpDataPtr<T>> *output_data_;
|
||||
std::vector<Promise<int>> *results_;
|
||||
// Record the error info for print.
|
||||
|
@ -100,11 +112,11 @@ class OpActor : public ActorBase {
|
|||
|
||||
protected:
|
||||
// The op data.
|
||||
std::unordered_map<uuids::uuid *, std::vector<OpData<T> *>> input_op_datas_;
|
||||
std::unordered_map<int, std::vector<OpData<T> *>> input_op_datas_;
|
||||
std::vector<DataArrowPtr> output_data_arrows_;
|
||||
|
||||
// The op controls.
|
||||
std::unordered_map<uuids::uuid *, std::vector<AID *>> input_op_controls_;
|
||||
std::unordered_map<int, std::vector<AID *>> input_op_controls_;
|
||||
std::vector<AID> output_control_arrows_;
|
||||
};
|
||||
|
||||
|
@ -128,8 +140,7 @@ int MindrtRun(const std::vector<OpDataPtr<T>> &input_data, std::vector<OpDataPtr
|
|||
const void *kernel_call_back_before, const void *kernel_call_back_after) {
|
||||
OpContext<T> context;
|
||||
std::vector<Promise<int>> promises(output_data->size());
|
||||
uuids::uuid uid;
|
||||
context.sequential_num_ = &uid;
|
||||
context.sequential_num_ = RandInt::Instance().Get();
|
||||
context.results_ = &promises;
|
||||
context.output_data_ = output_data;
|
||||
context.kernel_call_back_before_ = kernel_call_back_before;
|
||||
|
|
|
@ -33,8 +33,8 @@ using MessageHandler = std::function<void(ActorBase *)>;
|
|||
|
||||
class MessageAsync : public MessageBase {
|
||||
public:
|
||||
explicit MessageAsync(MessageHandler &h) : MessageBase("Async", Type::KASYNC), handler(h) {}
|
||||
~MessageAsync() override {}
|
||||
explicit MessageAsync(MessageHandler &&h) : MessageBase("Async", Type::KASYNC), handler(h) {}
|
||||
virtual ~MessageAsync() = default;
|
||||
void Run(ActorBase *actor) override { (handler)(actor); }
|
||||
|
||||
private:
|
||||
|
@ -52,7 +52,7 @@ struct AsyncHelper<void> {
|
|||
template <typename F>
|
||||
void operator()(const AID &aid, F &&f) {
|
||||
std::function<void(ActorBase *)> handler = [=](ActorBase *) { f(); };
|
||||
std::unique_ptr<MessageAsync> msg(new (std::nothrow) MessageAsync(handler));
|
||||
std::unique_ptr<MessageBase> msg(new (std::nothrow) MessageAsync(std::move(handler)));
|
||||
MINDRT_OOM_EXIT(msg);
|
||||
(void)ActorMgr::GetActorMgrRef()->Send(aid, std::move(msg));
|
||||
}
|
||||
|
@ -68,7 +68,7 @@ struct AsyncHelper<Future<R>> {
|
|||
|
||||
MessageHandler handler = [=](ActorBase *) { promise->Associate(f()); };
|
||||
|
||||
std::unique_ptr<MessageAsync> msg(new (std::nothrow) MessageAsync(handler));
|
||||
std::unique_ptr<MessageBase> msg(new (std::nothrow) MessageAsync(std::move(handler)));
|
||||
MINDRT_OOM_EXIT(msg);
|
||||
(void)ActorMgr::GetActorMgrRef()->Send(aid, std::move(msg));
|
||||
return future;
|
||||
|
@ -84,7 +84,7 @@ struct AsyncHelper {
|
|||
Future<R> future = promise->GetFuture();
|
||||
|
||||
std::function<void(ActorBase *)> handler = [=](ActorBase *) { promise->SetValue(f()); };
|
||||
std::unique_ptr<MessageAsync> msg(new (std::nothrow) MessageAsync(handler));
|
||||
std::unique_ptr<MessageBase> msg(new (std::nothrow) MessageAsync(std::move(handler)));
|
||||
MINDRT_OOM_EXIT(msg);
|
||||
(void)ActorMgr::GetActorMgrRef()->Send(aid, std::move(msg));
|
||||
return future;
|
||||
|
@ -102,7 +102,7 @@ void Async(const AID &aid, void (T::*method)()) {
|
|||
MINDRT_ASSERT(t != nullptr);
|
||||
(t->*method)();
|
||||
};
|
||||
std::unique_ptr<MessageAsync> msg(new (std::nothrow) MessageAsync(handler));
|
||||
std::unique_ptr<MessageBase> msg(new (std::nothrow) MessageAsync(std::move(handler)));
|
||||
MINDRT_OOM_EXIT(msg);
|
||||
(void)ActorMgr::GetActorMgrRef()->Send(aid, std::move(msg));
|
||||
}
|
||||
|
@ -115,7 +115,7 @@ void Async(const AID &aid, void (T::*method)(Arg0), Arg1 &&arg) {
|
|||
MINDRT_ASSERT(t != nullptr);
|
||||
(t->*method)(arg);
|
||||
};
|
||||
std::unique_ptr<MessageAsync> msg(new (std::nothrow) MessageAsync(handler));
|
||||
std::unique_ptr<MessageBase> msg(new (std::nothrow) MessageAsync(std::move(handler)));
|
||||
MINDRT_OOM_EXIT(msg);
|
||||
(void)ActorMgr::GetActorMgrRef()->Send(aid, std::move(msg));
|
||||
}
|
||||
|
@ -128,7 +128,7 @@ void Async(const AID &aid, void (T::*method)(Args0...), std::tuple<Args1...> &&t
|
|||
MINDRT_ASSERT(t != nullptr);
|
||||
Apply(t, method, tuple);
|
||||
};
|
||||
std::unique_ptr<MessageAsync> msg(new (std::nothrow) MessageAsync(handler));
|
||||
std::unique_ptr<MessageBase> msg(new (std::nothrow) MessageAsync(std::move(handler)));
|
||||
MINDRT_OOM_EXIT(msg);
|
||||
(void)ActorMgr::GetActorMgrRef()->Send(aid, std::move(msg));
|
||||
}
|
||||
|
@ -152,7 +152,7 @@ Future<R> Async(const AID &aid, Future<R> (T::*method)()) {
|
|||
MINDRT_ASSERT(t != nullptr);
|
||||
promise->Associate((t->*method)());
|
||||
};
|
||||
std::unique_ptr<MessageAsync> msg(new (std::nothrow) MessageAsync(handler));
|
||||
std::unique_ptr<MessageBase> msg(new (std::nothrow) MessageAsync(std::move(handler)));
|
||||
MINDRT_OOM_EXIT(msg);
|
||||
(void)ActorMgr::GetActorMgrRef()->Send(aid, std::move(msg));
|
||||
return future;
|
||||
|
@ -171,7 +171,7 @@ Future<R> Async(const AID &aid, Future<R> (T::*method)(Arg0), Arg1 &&arg) {
|
|||
promise->Associate((t->*method)(arg));
|
||||
};
|
||||
|
||||
std::unique_ptr<MessageAsync> msg(new (std::nothrow) MessageAsync(handler));
|
||||
std::unique_ptr<MessageBase> msg(new (std::nothrow) MessageAsync(std::move(handler)));
|
||||
MINDRT_OOM_EXIT(msg);
|
||||
(void)ActorMgr::GetActorMgrRef()->Send(aid, std::move(msg));
|
||||
return future;
|
||||
|
@ -190,7 +190,7 @@ Future<R> Async(const AID &aid, Future<R> (T::*method)(Args0...), std::tuple<Arg
|
|||
promise->Associate(Apply(t, method, tuple));
|
||||
};
|
||||
|
||||
std::unique_ptr<MessageAsync> msg(new (std::nothrow) MessageAsync(handler));
|
||||
std::unique_ptr<MessageBase> msg(new (std::nothrow) MessageAsync(std::move(handler)));
|
||||
MINDRT_OOM_EXIT(msg);
|
||||
(void)ActorMgr::GetActorMgrRef()->Send(aid, std::move(msg));
|
||||
return future;
|
||||
|
@ -217,7 +217,7 @@ Future<R> Async(const AID &aid, R (T::*method)()) {
|
|||
promise->SetValue((t->*method)());
|
||||
};
|
||||
|
||||
std::unique_ptr<MessageAsync> msg(new (std::nothrow) MessageAsync(handler));
|
||||
std::unique_ptr<MessageBase> msg(new (std::nothrow) MessageAsync(std::move(handler)));
|
||||
MINDRT_OOM_EXIT(msg);
|
||||
(void)ActorMgr::GetActorMgrRef()->Send(aid, std::move(msg));
|
||||
return future;
|
||||
|
@ -237,7 +237,7 @@ Future<R> Async(const AID &aid, R (T::*method)(Arg0), Arg1 &&arg) {
|
|||
MINDRT_ASSERT(t != nullptr);
|
||||
promise->SetValue((t->*method)(arg));
|
||||
};
|
||||
std::unique_ptr<MessageAsync> msg(new (std::nothrow) MessageAsync(handler));
|
||||
std::unique_ptr<MessageBase> msg(new (std::nothrow) MessageAsync(std::move(handler)));
|
||||
MINDRT_OOM_EXIT(msg);
|
||||
(void)ActorMgr::GetActorMgrRef()->Send(aid, std::move(msg));
|
||||
return future;
|
||||
|
@ -257,7 +257,7 @@ Future<R> Async(const AID &aid, R (T::*method)(Args0...), std::tuple<Args1...> &
|
|||
MINDRT_ASSERT(t != nullptr);
|
||||
promise->SetValue(Apply(t, method, tuple));
|
||||
};
|
||||
std::unique_ptr<MessageAsync> msg(new (std::nothrow) MessageAsync(handler));
|
||||
std::unique_ptr<MessageBase> msg(new (std::nothrow) MessageAsync(std::move(handler)));
|
||||
MINDRT_OOM_EXIT(msg);
|
||||
(void)ActorMgr::GetActorMgrRef()->Send(aid, std::move(msg));
|
||||
return future;
|
||||
|
|
|
@ -33,13 +33,13 @@ int Initialize(const std::string &tcpUrl, const std::string &tcpUrlAdv = "", con
|
|||
const std::string &udpUrlAdv = "", int threadCount = 0);
|
||||
|
||||
// brief spawn a process to run an actor
|
||||
AID Spawn(ActorReference actor, bool sharedThread = true, bool start = true);
|
||||
AID Spawn(ActorReference actor, bool sharedThread = true);
|
||||
|
||||
// brief wait for the actor process to exit . It will be discarded
|
||||
void Await(const ActorReference &actor);
|
||||
|
||||
// brief get actor with aid
|
||||
ActorBase *GetActor(const AID &actor);
|
||||
ActorReference GetActor(const AID &actor);
|
||||
|
||||
// brief wait for the actor process to exit
|
||||
void Await(const AID &actor);
|
||||
|
@ -47,9 +47,6 @@ void Await(const AID &actor);
|
|||
// brief Terminate the actor to exit
|
||||
void Terminate(const AID &actor);
|
||||
|
||||
// brief set the actor's running status
|
||||
void SetActorStatus(const AID &actor, bool start);
|
||||
|
||||
// brief Terminate all actors
|
||||
void TerminateAll();
|
||||
|
||||
|
|
|
@ -16,28 +16,24 @@
|
|||
|
||||
#include "actor/actor.h"
|
||||
#include "actor/actormgr.h"
|
||||
#include "actor/actorpolicyinterface.h"
|
||||
#include "actor/iomgr.h"
|
||||
|
||||
namespace mindspore {
|
||||
ActorBase::ActorBase() : actorPolicy(nullptr), id("", ActorMgr::GetActorMgrRef()->GetUrl()), actionFunctions() {}
|
||||
ActorBase::ActorBase() : mailbox(nullptr), id("", ActorMgr::GetActorMgrRef()->GetUrl()), actionFunctions() {}
|
||||
|
||||
ActorBase::ActorBase(const std::string &name)
|
||||
: actorPolicy(nullptr), id(name, ActorMgr::GetActorMgrRef()->GetUrl()), actionFunctions() {}
|
||||
: mailbox(nullptr), id(name, ActorMgr::GetActorMgrRef()->GetUrl()), actionFunctions() {}
|
||||
|
||||
ActorBase::ActorBase(const std::string &name, ActorThreadPool *pool)
|
||||
: actorPolicy(nullptr), id(name, ActorMgr::GetActorMgrRef()->GetUrl()), actionFunctions(), pool_(pool) {}
|
||||
: mailbox(nullptr), id(name, ActorMgr::GetActorMgrRef()->GetUrl()), actionFunctions(), pool_(pool) {}
|
||||
|
||||
ActorBase::~ActorBase() {}
|
||||
|
||||
void ActorBase::Spawn(const std::shared_ptr<ActorBase> &actor, std::unique_ptr<ActorPolicy> thread) {
|
||||
// lock here or await(). and unlock at Quit() or at aweit.
|
||||
void ActorBase::Spawn(const std::shared_ptr<ActorBase> &actor, std::unique_ptr<MailBox> mailboxPtr) {
|
||||
// lock here or await(). and unlock at Quit() or at await.
|
||||
waiterLock.lock();
|
||||
|
||||
actorPolicy = std::move(thread);
|
||||
MINDRT_OOM_EXIT(actorPolicy);
|
||||
this->mailbox = std::move(mailboxPtr);
|
||||
}
|
||||
void ActorBase::SetRunningStatus(bool start) { actorPolicy->SetRunningStatus(start); }
|
||||
|
||||
void ActorBase::Await() {
|
||||
std::string actorName = id.Name();
|
||||
|
@ -46,12 +42,19 @@ void ActorBase::Await() {
|
|||
|
||||
waiterLock.lock();
|
||||
waiterLock.unlock();
|
||||
|
||||
// mailbox's hook may hold the actor reference, we need explicitly free the mailbox to avoid the memory leak. the
|
||||
// details can refer to the comments in ActorMgr::Spawn
|
||||
delete mailbox.release();
|
||||
MS_LOG(DEBUG) << "ACTOR succeeded in waiting. a=" << actorName.c_str();
|
||||
}
|
||||
void ActorBase::Terminate() {
|
||||
std::unique_ptr<MessageBase> msg(new (std::nothrow) MessageBase("Terminate", MessageBase::Type::KTERMINATE));
|
||||
MINDRT_OOM_EXIT(msg);
|
||||
(void)EnqueMessage(std::move(msg));
|
||||
bool flag = false;
|
||||
if (terminating_.compare_exchange_strong(flag, true)) {
|
||||
std::unique_ptr<MessageBase> msg(new (std::nothrow) MessageBase("Terminate", MessageBase::Type::KTERMINATE));
|
||||
MINDRT_OOM_EXIT(msg);
|
||||
(void)EnqueMessage(std::move(msg));
|
||||
}
|
||||
}
|
||||
|
||||
void ActorBase::HandlekMsg(const std::unique_ptr<MessageBase> &msg) {
|
||||
|
@ -64,61 +67,76 @@ void ActorBase::HandlekMsg(const std::unique_ptr<MessageBase> &msg) {
|
|||
<< ",m=" << msg->Name().c_str();
|
||||
}
|
||||
}
|
||||
int ActorBase::EnqueMessage(std::unique_ptr<MessageBase> &&msg) { return actorPolicy->EnqueMessage(std::move(msg)); }
|
||||
int ActorBase::EnqueMessage(std::unique_ptr<MessageBase> msg) {
|
||||
int ret = mailbox->EnqueueMessage(std::move(msg));
|
||||
return ret;
|
||||
}
|
||||
|
||||
void ActorBase::Quit() {
|
||||
Finalize();
|
||||
// lock at spawn(), unlock here.
|
||||
waiterLock.unlock();
|
||||
|
||||
actorPolicy->Terminate(this);
|
||||
}
|
||||
|
||||
void ActorBase::Run() {
|
||||
for (;;) {
|
||||
auto msgs = actorPolicy->GetMsgs();
|
||||
if (msgs == nullptr) {
|
||||
return;
|
||||
}
|
||||
for (auto it = msgs->begin(); it != msgs->end(); ++it) {
|
||||
std::unique_ptr<MessageBase> &msg = *it;
|
||||
if (msg == nullptr) {
|
||||
continue;
|
||||
auto msgHandler = [this](const std::unique_ptr<MessageBase> &msg) {
|
||||
AddMsgRecord(msg->Name());
|
||||
switch (msg->GetType()) {
|
||||
case MessageBase::Type::KMSG:
|
||||
case MessageBase::Type::KUDP: {
|
||||
if (Filter(msg)) {
|
||||
return ERRORCODE_SUCCESS;
|
||||
}
|
||||
this->HandlekMsg(msg);
|
||||
return ERRORCODE_SUCCESS;
|
||||
}
|
||||
AddMsgRecord(msg->Name());
|
||||
switch (msg->GetType()) {
|
||||
case MessageBase::Type::KMSG:
|
||||
case MessageBase::Type::KUDP: {
|
||||
if (Filter(msg)) {
|
||||
continue;
|
||||
}
|
||||
this->HandlekMsg(msg);
|
||||
break;
|
||||
case MessageBase::Type::KHTTP: {
|
||||
this->HandleHttp(msg);
|
||||
return ERRORCODE_SUCCESS;
|
||||
}
|
||||
case MessageBase::Type::KASYNC: {
|
||||
msg->Run(this);
|
||||
return ERRORCODE_SUCCESS;
|
||||
}
|
||||
case MessageBase::Type::KLOCAL: {
|
||||
this->HandleLocalMsg(msg);
|
||||
return ERRORCODE_SUCCESS;
|
||||
}
|
||||
case MessageBase::Type::KTERMINATE: {
|
||||
this->Quit();
|
||||
return ACTOR_TERMINATED;
|
||||
}
|
||||
case MessageBase::Type::KEXIT: {
|
||||
this->Exited(msg->From());
|
||||
return ERRORCODE_SUCCESS;
|
||||
}
|
||||
}
|
||||
return ERRORCODE_SUCCESS;
|
||||
};
|
||||
|
||||
if (this->mailbox->TakeAllMsgsEachTime()) {
|
||||
while (auto msgs = mailbox->GetMsgs()) {
|
||||
for (auto it = msgs->begin(); it != msgs->end(); ++it) {
|
||||
std::unique_ptr<MessageBase> &msg = *it;
|
||||
if (msg == nullptr) {
|
||||
continue;
|
||||
}
|
||||
case MessageBase::Type::KHTTP: {
|
||||
this->HandleHttp(std::move(msg));
|
||||
break;
|
||||
}
|
||||
case MessageBase::Type::KASYNC: {
|
||||
msg->Run(this);
|
||||
break;
|
||||
}
|
||||
case MessageBase::Type::KLOCAL: {
|
||||
this->HandleLocalMsg(std::move(msg));
|
||||
break;
|
||||
}
|
||||
case MessageBase::Type::KTERMINATE: {
|
||||
this->Quit();
|
||||
MS_LOG_DEBUG << "dequeue message]actor=" << id.Name() << ",msg=" << msg->Name();
|
||||
if (msgHandler(msg) == ACTOR_TERMINATED) {
|
||||
return;
|
||||
}
|
||||
case MessageBase::Type::KEXIT: {
|
||||
this->Exited(msg->From());
|
||||
break;
|
||||
}
|
||||
}
|
||||
msgs->clear();
|
||||
}
|
||||
|
||||
} else {
|
||||
while (auto msg = mailbox->GetMsg()) {
|
||||
if (msgHandler(msg) == ACTOR_TERMINATED) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
msgs->clear();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
int ActorBase::Send(const AID &to, std::unique_ptr<MessageBase> msg) {
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
#include <utility>
|
||||
|
||||
#include "actor/actormgr.h"
|
||||
#include "actor/actorpolicy.h"
|
||||
#include "actor/iomgr.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -134,10 +133,7 @@ void ActorMgr::TerminateAll() {
|
|||
|
||||
// send terminal msg to all actors.
|
||||
for (auto actorIt = actorsWaiting.begin(); actorIt != actorsWaiting.end(); ++actorIt) {
|
||||
(*actorIt)->SetRunningStatus(true);
|
||||
std::unique_ptr<MessageBase> msg(new (std::nothrow) MessageBase("Terminate", MessageBase::Type::KTERMINATE));
|
||||
MINDRT_OOM_EXIT(msg);
|
||||
(void)(*actorIt)->EnqueMessage(std::move(msg));
|
||||
(*actorIt)->Terminate();
|
||||
}
|
||||
|
||||
// wait actor's thread to finish.
|
||||
|
@ -165,7 +161,7 @@ void ActorMgr::Finalize() {
|
|||
MS_LOG(INFO) << "mindrt IOMGRS finish exiting.";
|
||||
}
|
||||
|
||||
ActorBase *ActorMgr::GetActor(const AID &id) {
|
||||
ActorReference ActorMgr::GetActor(const AID &id) {
|
||||
#ifndef MS_COMPILE_IOS
|
||||
actorsMutex.lock_shared();
|
||||
#else
|
||||
|
@ -179,7 +175,7 @@ ActorBase *ActorMgr::GetActor(const AID &id) {
|
|||
#else
|
||||
actorsMutex.unlock();
|
||||
#endif
|
||||
return result.get();
|
||||
return result;
|
||||
} else {
|
||||
#ifndef MS_COMPILE_IOS
|
||||
actorsMutex.unlock_shared();
|
||||
|
@ -191,7 +187,11 @@ ActorBase *ActorMgr::GetActor(const AID &id) {
|
|||
}
|
||||
}
|
||||
|
||||
int ActorMgr::Send(const AID &to, std::unique_ptr<MessageBase> &&msg, bool remoteLink, bool isExactNotRemote) {
|
||||
int ActorMgr::EnqueueMessage(mindspore::ActorReference actor, std::unique_ptr<mindspore::MessageBase> msg) {
|
||||
return actor->EnqueMessage(std::move(msg));
|
||||
}
|
||||
|
||||
int ActorMgr::Send(const AID &to, std::unique_ptr<MessageBase> msg, bool remoteLink, bool isExactNotRemote) {
|
||||
// The destination is local
|
||||
if (IsLocalAddres(to)) {
|
||||
auto actor = GetActor(to);
|
||||
|
@ -199,7 +199,7 @@ int ActorMgr::Send(const AID &to, std::unique_ptr<MessageBase> &&msg, bool remot
|
|||
if (to.GetProtocol() == MINDRT_UDP && msg->GetType() == MessageBase::Type::KMSG) {
|
||||
msg->type = MessageBase::Type::KUDP;
|
||||
}
|
||||
return actor->EnqueMessage(std::move(msg));
|
||||
return EnqueueMessage(actor, std::move(msg));
|
||||
} else {
|
||||
return ACTOR_NOT_FIND;
|
||||
}
|
||||
|
@ -224,7 +224,7 @@ int ActorMgr::Send(const AID &to, std::unique_ptr<MessageBase> &&msg, bool remot
|
|||
}
|
||||
}
|
||||
|
||||
AID ActorMgr::Spawn(const ActorReference &actor, bool shareThread, bool start) {
|
||||
AID ActorMgr::Spawn(const ActorReference &actor, bool shareThread) {
|
||||
actorsMutex.lock();
|
||||
if (actors.find(actor->GetAID().Name()) != actors.end()) {
|
||||
actorsMutex.unlock();
|
||||
|
@ -232,17 +232,20 @@ AID ActorMgr::Spawn(const ActorReference &actor, bool shareThread, bool start) {
|
|||
MINDRT_EXIT("Actor name conflicts.");
|
||||
}
|
||||
MS_LOG(DEBUG) << "ACTOR was spawned,a=" << actor->GetAID().Name().c_str();
|
||||
std::unique_ptr<ActorPolicy> threadPolicy;
|
||||
|
||||
if (shareThread) {
|
||||
threadPolicy.reset(new (std::nothrow) ShardedThread(actor));
|
||||
MINDRT_OOM_EXIT(threadPolicy);
|
||||
actor->Spawn(actor, std::move(threadPolicy));
|
||||
auto mailbox = std::unique_ptr<MailBox>(new (std::nothrow) NonblockingMailBox());
|
||||
auto hook = std::unique_ptr<std::function<void()>>(
|
||||
new std::function<void()>([actor]() { ActorMgr::GetActorMgrRef()->SetActorReady(actor); }));
|
||||
// the mailbox has this hook, the hook holds the actor reference, the actor has the mailbox. this is a cycle which
|
||||
// will leads to memory leak. in order to fix this issue, we should explicitly free the mailbox when terminate the
|
||||
// actor
|
||||
mailbox->SetNotifyHook(std::move(hook));
|
||||
actor->Spawn(actor, std::move(mailbox));
|
||||
|
||||
} else {
|
||||
threadPolicy.reset(new (std::nothrow) SingleThread());
|
||||
MINDRT_OOM_EXIT(threadPolicy);
|
||||
actor->Spawn(actor, std::move(threadPolicy));
|
||||
auto mailbox = std::unique_ptr<MailBox>(new (std::nothrow) BlockingMailBox());
|
||||
actor->Spawn(actor, std::move(mailbox));
|
||||
ActorMgr::GetActorMgrRef()->SetActorReady(actor);
|
||||
}
|
||||
|
||||
|
@ -252,27 +255,16 @@ AID ActorMgr::Spawn(const ActorReference &actor, bool shareThread, bool start) {
|
|||
// long time
|
||||
actor->Init();
|
||||
|
||||
actor->SetRunningStatus(start);
|
||||
|
||||
return actor->GetAID();
|
||||
}
|
||||
|
||||
void ActorMgr::Terminate(const AID &id) {
|
||||
auto actor = GetActor(id);
|
||||
if (actor != nullptr) {
|
||||
std::unique_ptr<MessageBase> msg(new (std::nothrow) MessageBase("Terminate", MessageBase::Type::KTERMINATE));
|
||||
MINDRT_OOM_EXIT(msg);
|
||||
(void)actor->EnqueMessage(std::move(msg));
|
||||
|
||||
actor->Terminate();
|
||||
// Wait actor's thread to finish.
|
||||
actor->Await();
|
||||
}
|
||||
}
|
||||
|
||||
void ActorMgr::SetActorStatus(const AID &pid, bool start) {
|
||||
auto actor = GetActor(pid);
|
||||
if (actor != nullptr) {
|
||||
actor->SetRunningStatus(start);
|
||||
RemoveActor(id.Name());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -43,7 +43,7 @@ class ActorMgr {
|
|||
|
||||
static inline std::shared_ptr<IOMgr> &GetIOMgrRef(const AID &to) { return GetIOMgrRef(to.GetProtocol()); }
|
||||
|
||||
static void Receive(std::unique_ptr<MessageBase> &&msg) {
|
||||
static void Receive(std::unique_ptr<MessageBase> msg) {
|
||||
auto to = msg->To().Name();
|
||||
(void)ActorMgr::GetActorMgrRef()->Send(AID(to), std::move(msg));
|
||||
}
|
||||
|
@ -58,21 +58,19 @@ class ActorMgr {
|
|||
int Initialize(bool use_inner_pool = false, size_t actor_thread_num = 1, size_t max_thread_num = 1);
|
||||
|
||||
void RemoveActor(const std::string &name);
|
||||
ActorBase *GetActor(const AID &id);
|
||||
ActorReference GetActor(const AID &id);
|
||||
const std::string GetUrl(const std::string &protocol = "tcp");
|
||||
void AddUrl(const std::string &protocol, const std::string &url);
|
||||
void AddIOMgr(const std::string &protocol, const std::shared_ptr<IOMgr> &ioMgr);
|
||||
int Send(const AID &to, std::unique_ptr<MessageBase> &&msg, bool remoteLink = false, bool isExactNotRemote = false);
|
||||
AID Spawn(const ActorReference &actor, bool shareThread = true, bool start = true);
|
||||
int Send(const AID &to, std::unique_ptr<MessageBase> msg, bool remoteLink = false, bool isExactNotRemote = false);
|
||||
AID Spawn(const ActorReference &actor, bool shareThread = true);
|
||||
void Terminate(const AID &id);
|
||||
void TerminateAll();
|
||||
void Wait(const AID &pid);
|
||||
inline const std::string &GetDelegate() const { return delegate; }
|
||||
|
||||
inline void SetDelegate(const std::string &d) { delegate = d; }
|
||||
|
||||
void SetActorReady(const ActorReference &actor) const;
|
||||
void SetActorStatus(const AID &pid, bool start);
|
||||
|
||||
private:
|
||||
inline bool IsLocalAddres(const AID &id) {
|
||||
|
@ -82,6 +80,7 @@ class ActorMgr {
|
|||
return false;
|
||||
}
|
||||
}
|
||||
int EnqueueMessage(ActorReference actor, std::unique_ptr<MessageBase> msg);
|
||||
// in order to avoid being initialized many times
|
||||
std::atomic_bool initialized_{false};
|
||||
|
||||
|
|
|
@ -1,121 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "actor/actor.h"
|
||||
#include "actor/actormgr.h"
|
||||
#include "actor/actorpolicy.h"
|
||||
|
||||
namespace mindspore {
|
||||
void ActorPolicy::SetRunningStatus(bool startRun) {
|
||||
std::lock_guard<std::mutex> lock(mailboxLock);
|
||||
this->start = startRun;
|
||||
Notify();
|
||||
}
|
||||
|
||||
SingleThread::SingleThread() {}
|
||||
SingleThread::~SingleThread() {}
|
||||
|
||||
void SingleThread::Terminate(const ActorBase *actor) {
|
||||
MINDRT_OOM_EXIT(actor);
|
||||
std::string actorName = actor->GetAID().Name();
|
||||
MS_LOG(DEBUG) << "ACTOR SingleThread received terminate message, v=" << actorName.c_str();
|
||||
// remove actor from actorMgr
|
||||
ActorMgr::GetActorMgrRef()->RemoveActor(actorName);
|
||||
}
|
||||
|
||||
int SingleThread::EnqueMessage(std::unique_ptr<MessageBase> &&msg) {
|
||||
int result;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mailboxLock);
|
||||
enqueMailbox->push_back(std::move(msg));
|
||||
result = ++msgCount;
|
||||
}
|
||||
// Notify when the count of message is from empty to one.
|
||||
if (start && result == 1) {
|
||||
conditionVar.notify_one();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
void SingleThread::Notify() {
|
||||
if (start && msgCount > 0) {
|
||||
conditionVar.notify_one();
|
||||
}
|
||||
}
|
||||
|
||||
std::list<std::unique_ptr<MessageBase>> *SingleThread::GetMsgs() {
|
||||
std::list<std::unique_ptr<MessageBase>> *result;
|
||||
std::unique_lock<std::mutex> lock(mailboxLock);
|
||||
conditionVar.wait(lock, [this] { return (!this->enqueMailbox->empty()); });
|
||||
SwapMailbox();
|
||||
|
||||
// REF_PRIVATE_MEMBER
|
||||
result = dequeMailbox;
|
||||
return result;
|
||||
}
|
||||
|
||||
ShardedThread::ShardedThread(const std::shared_ptr<ActorBase> &aActor)
|
||||
: ready(false), terminated(false), actor(aActor) {}
|
||||
ShardedThread::~ShardedThread() {}
|
||||
|
||||
void ShardedThread::Terminate(const ActorBase *aActor) {
|
||||
std::string actorName = aActor->GetAID().Name();
|
||||
|
||||
mailboxLock.lock();
|
||||
terminated = true;
|
||||
this->actor = nullptr;
|
||||
mailboxLock.unlock();
|
||||
|
||||
// remove actor from actorMgr
|
||||
ActorMgr::GetActorMgrRef()->RemoveActor(actorName);
|
||||
}
|
||||
|
||||
int ShardedThread::EnqueMessage(std::unique_ptr<MessageBase> &&msg) {
|
||||
mailboxLock.lock();
|
||||
enqueMailbox->push_back(std::move(msg));
|
||||
++msgCount;
|
||||
|
||||
// true : The actor is running. else the actor will be ready to run.
|
||||
if (start && (ready == false) && (terminated == false)) {
|
||||
ActorMgr::GetActorMgrRef()->SetActorReady(actor);
|
||||
ready = true;
|
||||
}
|
||||
mailboxLock.unlock();
|
||||
return msgCount;
|
||||
}
|
||||
|
||||
void ShardedThread::Notify() {
|
||||
if (start && ready == false && terminated == false && msgCount > 0) {
|
||||
ActorMgr::GetActorMgrRef()->SetActorReady(actor);
|
||||
ready = true;
|
||||
}
|
||||
}
|
||||
|
||||
std::list<std::unique_ptr<MessageBase>> *ShardedThread::GetMsgs() {
|
||||
std::list<std::unique_ptr<MessageBase>> *result;
|
||||
mailboxLock.lock();
|
||||
|
||||
if (enqueMailbox->empty()) {
|
||||
ready = false;
|
||||
result = nullptr;
|
||||
} else {
|
||||
ready = true;
|
||||
SwapMailbox();
|
||||
result = dequeMailbox;
|
||||
}
|
||||
mailboxLock.unlock();
|
||||
return result;
|
||||
}
|
||||
}; // end of namespace mindspore
|
|
@ -1,61 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_MINDRT_SRC_ACTOR_ACTORPOLICY_H
|
||||
#define MINDSPORE_CORE_MINDRT_SRC_ACTOR_ACTORPOLICY_H
|
||||
#include <list>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "actor/actorpolicyinterface.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
||||
class ShardedThread : public ActorPolicy {
|
||||
public:
|
||||
explicit ShardedThread(const std::shared_ptr<ActorBase> &actor);
|
||||
virtual ~ShardedThread();
|
||||
|
||||
protected:
|
||||
virtual void Terminate(const ActorBase *actor);
|
||||
virtual int EnqueMessage(std::unique_ptr<MessageBase> &&msg);
|
||||
virtual std::list<std::unique_ptr<MessageBase>> *GetMsgs();
|
||||
virtual void Notify();
|
||||
|
||||
private:
|
||||
bool ready;
|
||||
bool terminated;
|
||||
std::shared_ptr<ActorBase> actor;
|
||||
};
|
||||
|
||||
class SingleThread : public ActorPolicy {
|
||||
public:
|
||||
SingleThread();
|
||||
virtual ~SingleThread();
|
||||
|
||||
protected:
|
||||
virtual void Terminate(const ActorBase *actor);
|
||||
virtual int EnqueMessage(std::unique_ptr<MessageBase> &&msg);
|
||||
virtual std::list<std::unique_ptr<MessageBase>> *GetMsgs();
|
||||
virtual void Notify();
|
||||
|
||||
private:
|
||||
std::condition_variable conditionVar;
|
||||
};
|
||||
|
||||
}; // end of namespace mindspore
|
||||
#endif
|
|
@ -1,64 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_MINDRT_SRC_ACTOR_ACTORPOLICYINTERFACE_H
|
||||
#define MINDSPORE_CORE_MINDRT_SRC_ACTOR_ACTORPOLICYINTERFACE_H
|
||||
|
||||
#include <list>
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
|
||||
class ActorPolicy {
|
||||
public:
|
||||
ActorPolicy() : mailbox1(), mailbox2() {
|
||||
enqueMailbox = &mailbox1;
|
||||
dequeMailbox = &mailbox2;
|
||||
msgCount = 0;
|
||||
start = false;
|
||||
}
|
||||
virtual ~ActorPolicy() {}
|
||||
inline void SwapMailbox() {
|
||||
std::list<std::unique_ptr<MessageBase>> *temp;
|
||||
temp = enqueMailbox;
|
||||
enqueMailbox = dequeMailbox;
|
||||
dequeMailbox = temp;
|
||||
msgCount = 0;
|
||||
}
|
||||
|
||||
protected:
|
||||
void SetRunningStatus(bool startRun);
|
||||
virtual void Terminate(const ActorBase *actor) = 0;
|
||||
virtual int EnqueMessage(std::unique_ptr<MessageBase> &&msg) = 0;
|
||||
virtual std::list<std::unique_ptr<MessageBase>> *GetMsgs() = 0;
|
||||
virtual void Notify() = 0;
|
||||
|
||||
std::list<std::unique_ptr<MessageBase>> *enqueMailbox;
|
||||
std::list<std::unique_ptr<MessageBase>> *dequeMailbox;
|
||||
|
||||
int msgCount;
|
||||
bool start;
|
||||
std::mutex mailboxLock;
|
||||
|
||||
private:
|
||||
friend class ActorBase;
|
||||
|
||||
std::list<std::unique_ptr<MessageBase>> mailbox1;
|
||||
std::list<std::unique_ptr<MessageBase>> mailbox2;
|
||||
};
|
||||
|
||||
}; // end of namespace mindspore
|
||||
#endif
|
|
@ -0,0 +1,90 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "actor/mailbox.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
||||
int BlockingMailBox::EnqueueMessage(std::unique_ptr<mindspore::MessageBase> msg) {
|
||||
{
|
||||
std::unique_lock<std::mutex> ulk(lock);
|
||||
enqueMailBox->emplace_back(std::move(msg));
|
||||
}
|
||||
|
||||
cond.notify_all();
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::list<std::unique_ptr<MessageBase>> *BlockingMailBox::GetMsgs() {
|
||||
std::list<std::unique_ptr<MessageBase>> *ret;
|
||||
{
|
||||
std::unique_lock<std::mutex> ulk(lock);
|
||||
while (enqueMailBox->empty()) {
|
||||
cond.wait(ulk, [this] { return !this->enqueMailBox->empty(); });
|
||||
}
|
||||
SwapMailBox(&enqueMailBox, &dequeMailBox);
|
||||
ret = dequeMailBox;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
int NonblockingMailBox::EnqueueMessage(std::unique_ptr<mindspore::MessageBase> msg) {
|
||||
bool empty = false;
|
||||
bool released = false;
|
||||
{
|
||||
std::unique_lock<std::mutex> ulk(lock);
|
||||
empty = enqueMailBox->empty();
|
||||
enqueMailBox->emplace_back(std::move(msg));
|
||||
released = this->released_;
|
||||
}
|
||||
if (empty && released && notifyHook) {
|
||||
(*notifyHook.get())();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::list<std::unique_ptr<MessageBase>> *NonblockingMailBox::GetMsgs() {
|
||||
std::list<std::unique_ptr<MessageBase>> *ret;
|
||||
{
|
||||
std::unique_lock<std::mutex> ulk(lock);
|
||||
if (enqueMailBox->empty()) {
|
||||
released_ = true;
|
||||
return nullptr;
|
||||
}
|
||||
SwapMailBox(&enqueMailBox, &dequeMailBox);
|
||||
ret = dequeMailBox;
|
||||
released_ = false;
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
int HQueMailBox::EnqueueMessage(std::unique_ptr<mindspore::MessageBase> msg) {
|
||||
bool empty = mailbox.Empty();
|
||||
MessageBase *msgPtr = msg.release();
|
||||
while (!mailbox.Enqueue(msgPtr)) {
|
||||
}
|
||||
if (empty && notifyHook) {
|
||||
(*notifyHook.get())();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::unique_ptr<MessageBase> HQueMailBox::GetMsg() {
|
||||
std::unique_ptr<MessageBase> msg(mailbox.Dequeue());
|
||||
return msg;
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,105 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_MAILBOX_H
|
||||
#define MINDSPORE_MAILBOX_H
|
||||
#include <list>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <condition_variable>
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
#include "actor/msg.h"
|
||||
#include "thread/hqueue.h"
|
||||
|
||||
namespace mindspore {
|
||||
class MailBox {
|
||||
public:
|
||||
virtual ~MailBox() = default;
|
||||
virtual int EnqueueMessage(std::unique_ptr<MessageBase> msg) = 0;
|
||||
virtual std::list<std::unique_ptr<MessageBase>> *GetMsgs() = 0;
|
||||
virtual std::unique_ptr<MessageBase> GetMsg() = 0;
|
||||
inline void SetNotifyHook(std::unique_ptr<std::function<void()>> &&hook) { notifyHook = std::move(hook); }
|
||||
inline bool TakeAllMsgsEachTime() { return takeAllMsgsEachTime; }
|
||||
void SwapMailBox(std::list<std::unique_ptr<MessageBase>> **box1, std::list<std::unique_ptr<MessageBase>> **box2) {
|
||||
std::list<std::unique_ptr<MessageBase>> *tmp = *box1;
|
||||
*box1 = *box2;
|
||||
*box2 = tmp;
|
||||
}
|
||||
|
||||
protected:
|
||||
// if this flag is true, GetMsgs() should be invoked to take all enqueued msgs each time, otherwise we can only get
|
||||
// one msg by GetMsg() each time.
|
||||
bool takeAllMsgsEachTime = true;
|
||||
std::unique_ptr<std::function<void()>> notifyHook;
|
||||
};
|
||||
|
||||
class BlockingMailBox : public MailBox {
|
||||
public:
|
||||
BlockingMailBox() : mailbox1(), mailbox2(), enqueMailBox(&mailbox1), dequeMailBox(&mailbox2) {}
|
||||
virtual ~BlockingMailBox() {
|
||||
mailbox1.clear();
|
||||
mailbox2.clear();
|
||||
}
|
||||
int EnqueueMessage(std::unique_ptr<MessageBase> msg) override;
|
||||
std::list<std::unique_ptr<MessageBase>> *GetMsgs() override;
|
||||
std::unique_ptr<MessageBase> GetMsg() override { return nullptr; }
|
||||
|
||||
private:
|
||||
std::list<std::unique_ptr<MessageBase>> mailbox1;
|
||||
std::list<std::unique_ptr<MessageBase>> mailbox2;
|
||||
std::list<std::unique_ptr<MessageBase>> *enqueMailBox;
|
||||
std::list<std::unique_ptr<MessageBase>> *dequeMailBox;
|
||||
std::mutex lock;
|
||||
std::condition_variable cond;
|
||||
};
|
||||
|
||||
class NonblockingMailBox : public MailBox {
|
||||
public:
|
||||
NonblockingMailBox() : mailbox1(), mailbox2(), enqueMailBox(&mailbox1), dequeMailBox(&mailbox2) {}
|
||||
virtual ~NonblockingMailBox() {
|
||||
mailbox1.clear();
|
||||
mailbox2.clear();
|
||||
}
|
||||
int EnqueueMessage(std::unique_ptr<MessageBase> msg) override;
|
||||
std::list<std::unique_ptr<MessageBase>> *GetMsgs() override;
|
||||
std::unique_ptr<MessageBase> GetMsg() override { return nullptr; }
|
||||
|
||||
private:
|
||||
std::list<std::unique_ptr<MessageBase>> mailbox1;
|
||||
std::list<std::unique_ptr<MessageBase>> mailbox2;
|
||||
std::list<std::unique_ptr<MessageBase>> *enqueMailBox;
|
||||
std::list<std::unique_ptr<MessageBase>> *dequeMailBox;
|
||||
std::mutex lock;
|
||||
bool released_ = true;
|
||||
};
|
||||
|
||||
class HQueMailBox : public MailBox {
|
||||
public:
|
||||
HQueMailBox() { takeAllMsgsEachTime = false; }
|
||||
inline bool Init() { return mailbox.Init(MAX_MSG_QUE_SIZE); }
|
||||
int EnqueueMessage(std::unique_ptr<MessageBase> msg) override;
|
||||
std::list<std::unique_ptr<MessageBase>> *GetMsgs() override { return nullptr; }
|
||||
std::unique_ptr<MessageBase> GetMsg() override;
|
||||
|
||||
private:
|
||||
HQueue<MessageBase> mailbox;
|
||||
static const int32_t MAX_MSG_QUE_SIZE = 4096;
|
||||
};
|
||||
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_MAILBOX_H
|
|
@ -79,7 +79,7 @@ int Initialize(const std::string &tcpUrl, const std::string &tcpUrlAdv, const st
|
|||
return result;
|
||||
}
|
||||
|
||||
AID Spawn(const ActorReference actor, bool sharedThread, bool start) {
|
||||
AID Spawn(ActorReference actor, bool sharedThread) {
|
||||
if (actor == nullptr) {
|
||||
MS_LOG(ERROR) << "Actor is nullptr.";
|
||||
MINDRT_EXIT("Actor is nullptr.");
|
||||
|
@ -88,17 +88,16 @@ AID Spawn(const ActorReference actor, bool sharedThread, bool start) {
|
|||
if (local::g_finalizeMindrtStatus.load() == true) {
|
||||
return actor->GetAID();
|
||||
} else {
|
||||
return ActorMgr::GetActorMgrRef()->Spawn(actor, sharedThread, start);
|
||||
return ActorMgr::GetActorMgrRef()->Spawn(actor, sharedThread);
|
||||
}
|
||||
}
|
||||
void SetActorStatus(const AID &actor, bool start) { ActorMgr::GetActorMgrRef()->SetActorStatus(actor, start); }
|
||||
|
||||
void Await(const ActorReference &actor) { ActorMgr::GetActorMgrRef()->Wait(actor->GetAID()); }
|
||||
|
||||
void Await(const AID &actor) { ActorMgr::GetActorMgrRef()->Wait(actor); }
|
||||
|
||||
// brief get actor with aid
|
||||
ActorBase *GetActor(const AID &actor) { return ActorMgr::GetActorMgrRef()->GetActor(actor); }
|
||||
ActorReference GetActor(const AID &actor) { return ActorMgr::GetActorMgrRef()->GetActor(actor); }
|
||||
|
||||
void Terminate(const AID &actor) { ActorMgr::GetActorMgrRef()->Terminate(actor); }
|
||||
|
||||
|
|
|
@ -140,7 +140,7 @@ void ActorThreadPool::PushActorToQueue(ActorBase *actor) {
|
|||
|
||||
int ActorThreadPool::CreateThreads(size_t actor_thread_num, size_t all_thread_num, const std::vector<int> &core_list) {
|
||||
#ifdef USE_HQUEUE
|
||||
if (actor_queue_.Init(MAX_READY_ACTOR_NR) != THREAD_OK) {
|
||||
if (actor_queue_.Init(MAX_READY_ACTOR_NR) != true) {
|
||||
THREAD_ERROR("init actor queue failed.");
|
||||
return THREAD_ERROR;
|
||||
}
|
||||
|
|
|
@ -18,8 +18,6 @@
|
|||
#define MINDSPORE_CORE_MINDRT_RUNTIME_HQUEUE_H_
|
||||
#include <atomic>
|
||||
#include <vector>
|
||||
#include "actor/log.h"
|
||||
#include "thread/threadlog.h"
|
||||
|
||||
namespace mindspore {
|
||||
// implement a lock-free queue
|
||||
|
@ -50,10 +48,13 @@ class HQueue {
|
|||
HQueue() {}
|
||||
virtual ~HQueue() {}
|
||||
|
||||
int Init(int32_t sz) {
|
||||
bool Init(int32_t sz) {
|
||||
for (int32_t i = 0; i < sz; i++) {
|
||||
auto node = new HQNode<T>();
|
||||
THREAD_ERROR_IF_NULL(node);
|
||||
if (node == nullptr) {
|
||||
Clean();
|
||||
return false;
|
||||
}
|
||||
node->value = nullptr;
|
||||
node->free = true;
|
||||
node->next = {-1, 0};
|
||||
|
@ -64,7 +65,7 @@ class HQueue {
|
|||
qhead = {0, 0};
|
||||
qtail = {0, 0};
|
||||
nodes[0]->free = false;
|
||||
return THREAD_OK;
|
||||
return true;
|
||||
}
|
||||
|
||||
void Clean() {
|
||||
|
@ -163,6 +164,7 @@ class HQueue {
|
|||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
std::atomic<Pointer> qhead;
|
||||
std::atomic<Pointer> qtail;
|
||||
std::vector<HQNode<T> *> nodes;
|
||||
|
|
Loading…
Reference in New Issue