!19656 [MS][LITE]optimize mindrt with lock-free queue

Merge pull request !19656 from zhaizhiqiang/master
This commit is contained in:
i-robot 2021-07-10 06:59:00 +00:00 committed by Gitee
commit 0ae9af3199
7 changed files with 244 additions and 177 deletions

View File

@ -23,6 +23,7 @@
#include <mutex>
#include <string>
#include <utility>
#include "thread/hqueue.h"
#include "actor/msg.h"
@ -39,7 +40,7 @@ using ActorReference = std::shared_ptr<ActorBase>;
// should be at least greater than 1
constexpr uint32_t MAX_ACTOR_RECORD_SIZE = 3;
class ActorBase {
class ActorBase : std::enable_shared_from_this<ActorBase> {
public:
inline const AID &GetAID() const { return id; }
@ -57,7 +58,7 @@ class ActorBase {
startPoint = (startPoint + MAX_ACTOR_RECORD_SIZE - 1) % MAX_ACTOR_RECORD_SIZE;
}
}
ActorBase();
explicit ActorBase(const std::string &name);
explicit ActorBase(const std::string &name, ActorThreadPool *pool);
virtual ~ActorBase();

View File

@ -20,6 +20,7 @@
#include "actor/iomgr.h"
namespace mindspore {
ActorBase::ActorBase() : actorPolicy(nullptr), id("", ActorMgr::GetActorMgrRef()->GetUrl()), actionFunctions() {}
ActorBase::ActorBase(const std::string &name)
: actorPolicy(nullptr), id(name, ActorMgr::GetActorMgrRef()->GetUrl()), actionFunctions() {}

View File

@ -27,6 +27,7 @@
#endif
#include "actor/actor.h"
#include "thread/actor_threadpool.h"
#include "thread/hqueue.h"
namespace mindspore {
@ -64,13 +65,13 @@ class ActorMgr {
inline const std::string &GetDelegate() const { return delegate; }
inline void SetDelegate(const std::string &d) { delegate = d; }
inline void SetActorReady(const std::shared_ptr<ActorBase> &actor) const {
inline void SetActorReady(std::shared_ptr<ActorBase> &actor) const {
auto pool = actor->pool_;
if (pool == nullptr) {
MS_LOG(ERROR) << "ThreadPOol is nullptr, actor: " << actor->GetAID().Name();
MS_LOG(ERROR) << "ThreadPool is nullptr, actor: " << actor->GetAID().Name();
return;
}
pool->PushActorToQueue(actor);
pool->PushActorToQueue(actor.get());
}
void SetActorStatus(const AID &pid, bool start);

View File

@ -18,6 +18,7 @@
#include "thread/core_affinity.h"
namespace mindspore {
constexpr size_t MAX_READY_ACTOR_NR = 1024;
void ActorWorker::CreateThread(ActorThreadPool *pool) {
THREAD_RETURN_IF_NULL(pool);
pool_ = pool;
@ -54,13 +55,10 @@ bool ActorWorker::RunQueueActorTask() {
}
bool ActorWorker::Active() {
{
std::lock_guard<std::mutex> _l(mutex_);
if (status_ != kThreadIdle) {
return false;
}
status_ = kThreadBusy;
if (status_ != kThreadIdle) {
return false;
}
status_ = kThreadBusy;
cond_var_.notify_one();
return true;
}
@ -70,8 +68,12 @@ ActorThreadPool::~ActorThreadPool() {
bool terminate = false;
do {
{
#ifdef USE_HQUEUE
terminate = actor_queue_.Empty();
#else
std::lock_guard<std::mutex> _l(actor_mutex_);
terminate = actor_queue_.empty();
#endif
}
if (!terminate) {
std::this_thread::yield();
@ -82,9 +84,15 @@ ActorThreadPool::~ActorThreadPool() {
worker = nullptr;
}
workers_.clear();
#ifdef USE_HQUEUE
actor_queue_.Clean();
#endif
}
ActorReference ActorThreadPool::PopActorFromQueue() {
ActorBase *ActorThreadPool::PopActorFromQueue() {
#ifdef USE_HQUEUE
return actor_queue_.Dequeue();
#else
std::lock_guard<std::mutex> _l(actor_mutex_);
if (actor_queue_.empty()) {
return nullptr;
@ -92,12 +100,21 @@ ActorReference ActorThreadPool::PopActorFromQueue() {
auto actor = actor_queue_.front();
actor_queue_.pop();
return actor;
#endif
}
void ActorThreadPool::PushActorToQueue(const ActorReference &actor) {
void ActorThreadPool::PushActorToQueue(ActorBase *actor) {
if (!actor) {
return;
}
{
#ifdef USE_HQUEUE
while (!actor_queue_.Enqueue(actor)) {
}
#else
std::lock_guard<std::mutex> _l(actor_mutex_);
actor_queue_.push(actor);
#endif
}
THREAD_INFO("actor[%s] enqueue success", actor->GetAID().Name().c_str());
// active one idle actor thread if exist
@ -110,6 +127,10 @@ void ActorThreadPool::PushActorToQueue(const ActorReference &actor) {
}
int ActorThreadPool::CreateThreads(size_t actor_thread_num, size_t all_thread_num) {
#ifdef USE_HQUEUE
actor_queue_.Init(MAX_READY_ACTOR_NR);
#endif
size_t core_num = std::thread::hardware_concurrency();
THREAD_INFO("ThreadInfo, Actor: [%zu], All: [%zu], CoreNum: [%zu]", actor_thread_num, all_thread_num, core_num);
actor_thread_num_ = actor_thread_num < core_num ? actor_thread_num : core_num;

View File

@ -24,9 +24,8 @@
#include "thread/threadpool.h"
#include "actor/actor.h"
#include "thread/hqueue.h"
#define USE_HQUEUE
namespace mindspore {
class ActorThreadPool;
class ActorWorker : public Worker {
@ -49,18 +48,21 @@ class ActorThreadPool : public ThreadPool {
static ActorThreadPool *CreateThreadPool(size_t thread_num);
~ActorThreadPool() override;
void PushActorToQueue(const ActorReference &actor);
ActorReference PopActorFromQueue();
void PushActorToQueue(ActorBase *actor);
ActorBase *PopActorFromQueue();
private:
ActorThreadPool() {}
int CreateThreads(size_t actor_thread_num, size_t all_thread_num);
size_t actor_thread_num_{0};
std::mutex actor_mutex_;
// std::condition_variable actor_cond_;
std::queue<ActorReference> actor_queue_;
std::condition_variable actor_cond_;
#ifdef USE_HQUEUE
HQueue<ActorBase> actor_queue_;
#else
std::queue<ActorBase *> actor_queue_;
#endif
};
} // namespace mindspore
#endif // MINDSPORE_CORE_MINDRT_RUNTIME_ACTOR_THREADPOOL_H_

View File

@ -18,17 +18,27 @@
#define MINDSPORE_CORE_MINDRT_RUNTIME_HQUEUE_H_
#include <atomic>
#include <vector>
#include "actor/log.h"
namespace mindspore {
// implement a lock-free queue
// refer to https://www.cs.rochester.edu/u/scott/papers/1996_PODC_queues.pdf
template <typename T>
class HQueue;
struct Pointer {
int32_t index = -1;
uint32_t version = 0;
bool operator==(const Pointer &that) { return (index == that.index && version == that.version); }
bool operator!=(const Pointer &that) { return !(*this == that); }
};
template <typename T>
struct HQNode {
HQNode() {}
HQNode(const T &t_, HQNode<T> *n) : t(t_), next(n) {}
T t;
std::atomic<HQNode<T> *> next = nullptr;
std::atomic<Pointer> next;
T *value = nullptr;
std::atomic_bool free = {true};
};
template <typename T>
@ -37,94 +47,123 @@ class HQueue {
HQueue(const HQueue &) = delete;
HQueue &operator=(const HQueue &) = delete;
HQueue() {}
virtual ~HQueue() {
// delete dummy head
HQNode<T> *node = this->qhead;
delete node;
}
virtual ~HQueue() {}
bool Init() {
HQNode<T> *dummyHead = new HQNode<T>();
if (!dummyHead) {
return false;
}
qhead = dummyHead;
qtail = dummyHead;
return true;
}
bool Enqueue(const T &data) {
HQNode<T> *node = new HQNode<T>(data, nullptr);
if (!node) {
return false;
void Init(int32_t sz) {
for (int32_t i = 0; i < sz; i++) {
auto node = new HQNode<T>();
node->value = nullptr;
node->free = true;
node->next = {-1, 0};
nodes.template emplace_back(node);
}
HQNode<T> *tail = nullptr;
HQNode<T> *next = nullptr;
// init first node as dummy head
qhead = {0, 0};
qtail = {0, 0};
nodes[0]->free = false;
return;
}
void Clean() {
for (auto node : nodes) {
delete node;
}
nodes.clear();
}
bool Enqueue(T *t) {
HQNode<T> *node = nullptr;
int32_t nodeIdx;
for (nodeIdx = 0; nodeIdx < static_cast<int32_t>(nodes.size()); nodeIdx++) {
bool expected = true;
if (nodes[nodeIdx]->free.compare_exchange_strong(expected, false)) {
node = nodes[nodeIdx];
break;
}
}
if (node == nullptr) {
return false;
}
node->value = t;
node->next = {-1, 0};
while (true) {
tail = this->qtail;
next = tail->next;
Pointer tail = qtail;
if (tail.index == -1) {
continue;
}
Pointer next = nodes[tail.index]->next;
if (tail != this->qtail) {
continue;
}
if (next == nullptr) {
if (tail->next.compare_exchange_strong(next, node)) {
break;
}
} else {
this->qtail.compare_exchange_strong(tail, next);
if (next.index != -1) {
this->qtail.compare_exchange_strong(tail, {next.index, tail.version + 1});
continue;
}
if (nodes[tail.index]->next.compare_exchange_strong(next, {nodeIdx, next.version + 1})) {
this->qtail.compare_exchange_strong(tail, {nodeIdx, tail.version + 1});
break;
}
}
this->qtail.compare_exchange_weak(tail, node);
return true;
}
bool Dequeue(T *data) {
HQNode<T> *head = nullptr;
HQNode<T> *tail = nullptr;
HQNode<T> *next = nullptr;
T *Dequeue() {
while (true) {
head = this->qhead;
tail = this->qtail;
next = head->next;
T *ret = nullptr;
Pointer head = qhead;
Pointer tail = qtail;
if (head.index == -1) {
continue;
}
Pointer next = nodes[head.index]->next;
if (head != this->qhead) {
continue;
}
if (head == tail) {
if (next == nullptr) {
return false;
if (head.index == tail.index) {
if (next.index == -1) {
return nullptr;
}
this->qtail.compare_exchange_strong(tail, next);
this->qtail.compare_exchange_strong(tail, {next.index, tail.version + 1});
} else {
*data = next->t;
if (this->qhead.compare_exchange_strong(head, next)) {
break;
if (next.index == -1) {
continue;
}
ret = nodes[next.index]->value;
if (this->qhead.compare_exchange_strong(head, {next.index, head.version + 1})) {
// free head
nodes[head.index]->free = true;
return ret;
}
}
}
delete head;
return true;
}
bool Empty() {
HQNode<T> *head = this->qhead;
HQNode<T> *tail = this->qtail;
HQNode<T> *next = head->next;
if (head == this->qhead && head == tail && next == nullptr) {
Pointer head = qhead;
Pointer tail = qtail;
if (head.index < 0) {
return false;
}
Pointer next = nodes[head.index]->next;
return true;
if (head == this->qhead && head.index == tail.index && next.index == -1) {
return true;
}
return false;
}
private:
std::atomic<HQNode<T> *> qhead;
std::atomic<HQNode<T> *> qtail;
std::atomic<Pointer> qhead;
std::atomic<Pointer> qtail;
std::vector<HQNode<T> *> nodes;
};
} // namespace mindspore

View File

@ -13,13 +13,17 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// #include <sys/time.h>
#include "actor/actor.h"
#include "actor/op_actor.h"
#include "async/uuid_base.h"
#include "async/future.h"
#include "src/lite_mindrt.h"
#include "thread/hqueue.h"
#include "thread/actor_threadpool.h"
#include "common/common_test.h"
#include "schema/model_generated.h"
#include "include/model.h"
namespace mindspore {
class LiteMindRtTest : public mindspore::CommonTest {
@ -27,71 +31,71 @@ class LiteMindRtTest : public mindspore::CommonTest {
LiteMindRtTest() {}
};
TEST_F(LiteMindRtTest, HQueueTest) {
HQueue<int *> hq;
hq.Init();
std::vector<int *> v1(2000);
int d1 = 1;
for (size_t s = 0; s < v1.size(); s++) {
v1[s] = new int(d1);
}
std::vector<int *> v2(2000);
int d2 = 2;
for (size_t s = 0; s < v2.size(); s++) {
v2[s] = new int(d2);
}
std::thread t1([&]() {
for (size_t s = 0; s < v1.size(); s++) {
hq.Enqueue(v1[s]);
}
});
std::thread t2([&]() {
for (size_t s = 0; s < v2.size(); s++) {
hq.Enqueue(v2[s]);
}
});
int c1 = 0;
int c2 = 0;
std::thread t3([&]() {
size_t loop = v1.size() + v2.size();
while (loop) {
int *val = nullptr;
hq.Dequeue(&val);
if (val == nullptr) {
continue;
}
loop--;
if (*val == d1) {
c1++;
} else if (*val == d2) {
c2++;
} else {
// should never come here
ASSERT_EQ(0, 1);
}
}
});
t1.join();
t2.join();
t3.join();
ASSERT_EQ(c1, v1.size());
ASSERT_EQ(c2, v2.size());
int *tmp = nullptr;
ASSERT_EQ(hq.Dequeue(&tmp), false);
for (size_t s = 0; s < v1.size(); s++) {
delete v1[s];
}
for (size_t s = 0; s < v2.size(); s++) {
delete v2[s];
}
}
// TEST_F(LiteMindRtTest, HQueueTest) {
// HQueue<int *> hq;
// hq.Init();
// std::vector<int *> v1(2000);
// int d1 = 1;
// for (size_t s = 0; s < v1.size(); s++) {
// v1[s] = new int(d1);
// }
// std::vector<int *> v2(2000);
// int d2 = 2;
// for (size_t s = 0; s < v2.size(); s++) {
// v2[s] = new int(d2);
// }
//
// std::thread t1([&]() {
// for (size_t s = 0; s < v1.size(); s++) {
// hq.Enqueue(v1[s]);
// }
// });
// std::thread t2([&]() {
// for (size_t s = 0; s < v2.size(); s++) {
// hq.Enqueue(v2[s]);
// }
// });
//
// int c1 = 0;
// int c2 = 0;
//
// std::thread t3([&]() {
// size_t loop = v1.size() + v2.size();
// while (loop) {
// int *val = nullptr;
// hq.Dequeue(&val);
// if (val == nullptr) {
// continue;
// }
// loop--;
// if (*val == d1) {
// c1++;
// } else if (*val == d2) {
// c2++;
// } else {
// // should never come here
// ASSERT_EQ(0, 1);
// }
// }
// });
//
// t1.join();
// t2.join();
// t3.join();
//
// ASSERT_EQ(c1, v1.size());
// ASSERT_EQ(c2, v2.size());
// int *tmp = nullptr;
// ASSERT_EQ(hq.Dequeue(&tmp), false);
//
// for (size_t s = 0; s < v1.size(); s++) {
// delete v1[s];
// }
//
// for (size_t s = 0; s < v2.size(); s++) {
// delete v2[s];
// }
//}
class TestActor : public ActorBase {
public:
@ -109,46 +113,44 @@ class TestActor : public ActorBase {
};
TEST_F(LiteMindRtTest, ActorThreadPoolTest) {
Initialize("", "", "", "", 4);
auto pool = ActorThreadPool::CreateThreadPool(4);
AID t1 = Spawn(ActorReference(new TestActor("t1", pool, 1)));
AID t2 = Spawn(ActorReference(new TestActor("t2", pool, 2)));
AID t3 = Spawn(ActorReference(new TestActor("t3", pool, 3)));
AID t4 = Spawn(ActorReference(new TestActor("t4", pool, 4)));
AID t5 = Spawn(ActorReference(new TestActor("t5", pool, 5)));
AID t6 = Spawn(ActorReference(new TestActor("t6", pool, 6)));
Initialize("", "", "", "", 40);
auto pool = ActorThreadPool::CreateThreadPool(40);
std::vector<AID> actors;
for (size_t i = 0; i < 200; i++) {
AID t1 = Spawn(ActorReference(new TestActor(to_string(i), pool, i)));
actors.emplace_back(t1);
}
std::vector<int *> vv;
std::vector<Future<int>> fv;
size_t sz = 2000;
std::vector<int> expected;
size_t sz = 300;
// struct timeval start, end;
// gettimeofday(&start, NULL);
for (size_t i = 0; i < sz; i++) {
vv.emplace_back(new int(i));
int data = 0;
for (auto a : actors) {
int *val = new int(i);
vv.emplace_back(val);
Future<int> ret;
ret = Async(a, &TestActor::Fn1, val) // (*vv[i])++;
.Then(Defer(a, &TestActor::Fn2, val), ret); // t2.data += (*vv[i]);
fv.emplace_back(ret);
expected.emplace_back(data + i + 1);
data++;
}
}
for (size_t i = 0; i < sz; i++) {
int *val = vv[i];
Future<int> ret;
ret = Async(t1, &TestActor::Fn1, val) // (*vv[i])++;
.Then(Defer(t2, &TestActor::Fn2, val), ret) // t2.data += (*vv[i]);
.Then(Defer(t3, &TestActor::Fn1, val), ret) // (*vv[i])++;
.Then(Defer(t4, &TestActor::Fn2, val), ret) // t4.data += (*vv[i]);
.Then(Defer(t5, &TestActor::Fn1, val), ret) // (*vv[i])++;
.Then(Defer(t6, &TestActor::Fn2, val), ret); // t6.data += (*vv[i]);
fv.emplace_back(ret);
}
for (size_t i = 0; i < vv.size(); i++) {
int val = static_cast<int>(i);
int expected = 0;
val += 3; // t1.Fn1
expected = 6; // t6.data
expected += val;
ASSERT_EQ(fv[i].Get(), expected);
ASSERT_EQ(*vv[i], val);
int ret = fv[i].Get();
ASSERT_EQ(ret, expected[i]);
}
// gettimeofday(&end, NULL);
//
// std::cout << "start: " << start.tv_sec << "." << start.tv_usec << std::endl;
// std::cout << "end: " << end.tv_sec << "." << end.tv_usec << std::endl;
// std::cout << "consumed: " << (end.tv_sec - start.tv_sec) * 1000000 + (end.tv_usec - start.tv_usec) << " us"
// << std::endl;
Finalize();