Add operator parallelism

gongdaguo 2022-02-28 14:38:57 +08:00
parent e6d4a725ea
commit 5cb747cf4b
13 changed files with 386 additions and 45 deletions

View File

@ -3,6 +3,7 @@
# Scene2:
# file_path:function_name1, function_name2
#
mindspore/mindspore/core/mindrt/src/thread/actor_threadpool.cc:mindspore::ActorWorker::RunWithSpin
mindspore/mindspore/lite/src/ops/primitive_c.cc:mindspore::lite::PrimitiveC::Create
mindspore/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc:mindspore::dataset::CsvOp::CsvParser::InitCsvParser
mindspore/mindspore/lite/tools/converter/graphdef_transform.cc:mindspore::lite::GraphDefTransform::Transform

View File

@ -22,11 +22,7 @@
namespace mindspore {
constexpr size_t MAX_READY_ACTOR_NR = 4096;
void ActorWorker::CreateThread(ActorThreadPool *pool) {
THREAD_RETURN_IF_NULL(pool);
pool_ = pool;
thread_ = std::thread(&ActorWorker::RunWithSpin, this);
}
void ActorWorker::CreateThread() { thread_ = std::thread(&ActorWorker::RunWithSpin, this); }
void ActorWorker::RunWithSpin() {
SetAffinity();
@ -41,24 +37,43 @@ void ActorWorker::RunWithSpin() {
#endif
while (alive_) {
// run the local KernelTask first, then fall back to the queued actor/work tasks
if (RunLocalKernelTask() || RunQueueActorTask()) {
if (RunLocalKernelTask()) {
spin_count_ = 0;
} else {
YieldAndDeactive();
}
#ifdef OPERATOR_PARALLELISM
if (RunQueueActorTask() || RunQueueWorkTask()) {
#else
if (RunQueueActorTask()) {
#endif
if (spin_count_ > 0) {
spin_count_ = 1;
}
}
if (spin_count_ > max_spin_count_) {
WaitUntilActive();
spin_count_ = 0;
spin_count_ = 1;
}
}
}
bool ActorWorker::RunQueueActorTask() {
THREAD_ERROR_IF_NULL(pool_);
auto actor = pool_->PopActorFromQueue();
if (pool_ == nullptr) {
return false;
}
auto actor = static_cast<ActorThreadPool *>(pool_)->PopActorFromQueue();
if (actor == nullptr) {
return false;
}
#ifndef OPERATOR_PARALLELISM
if (available() || check_task_nullptr()) {
status_ = kThreadBusy;
set_task_free(true);
} else {
set_task_free(false);
}
#endif
actor->Run();
return true;
}
@ -149,6 +164,12 @@ int ActorThreadPool::CreateThreads(size_t actor_thread_num, size_t all_thread_nu
THREAD_ERROR("init actor queue failed.");
return THREAD_ERROR;
}
#ifdef OPERATOR_PARALLELISM
if (!task_queue_.Init(MAX_READY_TASK_NR)) {
THREAD_ERROR("Init task queue failed");
return THREAD_ERROR;
}
#endif
#endif
if (affinity_ != nullptr) {
affinity_->SetCoreId(core_list);
@ -162,10 +183,22 @@ int ActorThreadPool::CreateThreads(size_t actor_thread_num, size_t all_thread_nu
}
for (size_t i = 0; i < actor_thread_num_; ++i) {
std::lock_guard<std::mutex> _l(pool_mutex_);
auto worker = new (std::nothrow) ActorWorker();
auto worker = new (std::nothrow) ActorWorker(this);
THREAD_ERROR_IF_NULL(worker);
#ifdef OPERATOR_PARALLELISM
auto task_messages = static_cast<TaskMessage *>(malloc(sizeof(TaskMessage) * all_thread_num));
if (task_messages == nullptr) {
delete worker;
THREAD_ERROR("malloc TaskMessages failed.");
return THREAD_ERROR;
}
for (size_t j = 0; j < all_thread_num; j++) {
task_messages[j].task_id = j;
}
worker->SetTaskMessages(task_messages);
#endif
worker->InitWorkerMask(core_list, workers_.size());
worker->CreateThread(this);
worker->CreateThread();
workers_.push_back(worker);
THREAD_INFO("create actor thread[%zu]", i);
}

View File

@ -31,14 +31,13 @@ namespace mindspore {
class ActorThreadPool;
class ActorWorker : public Worker {
public:
void CreateThread(ActorThreadPool *pool);
explicit ActorWorker(ThreadPool *pool) : Worker(pool) {}
void CreateThread() override;
bool ActorActive();
private:
void RunWithSpin();
bool RunQueueActorTask();
ActorThreadPool *pool_{nullptr};
};
class ActorThreadPool : public ThreadPool {
@ -58,7 +57,6 @@ class ActorThreadPool : public ThreadPool {
private:
ActorThreadPool() {}
int CreateThreads(size_t actor_thread_num, size_t all_thread_num, const std::vector<int> &core_list);
size_t actor_thread_num_{0};
std::mutex actor_mutex_;
std::condition_variable actor_cond_;

View File

@ -46,7 +46,12 @@ class HQueue {
HQueue() {}
virtual ~HQueue() {}
bool IsInit() { return nodes.size() != 0; }
bool Init(int32_t sz) {
if (IsInit() || sz <= 0) {
return false;
}
for (int32_t i = 0; i < sz; i++) {
auto node = new HQNode<T>();
if (node == nullptr) {
@ -63,6 +68,8 @@ class HQueue {
qhead = {0, 0};
qtail = {0, 0};
nodes[0]->free = false;
queue_size = sz;
free_index = 1;
return true;
}
@ -75,17 +82,30 @@ class HQueue {
bool Enqueue(T *t) {
HQNode<T> *node = nullptr;
int32_t nodeIdx;
for (nodeIdx = 0; nodeIdx < static_cast<int32_t>(nodes.size()); nodeIdx++) {
int32_t nodeIdx = free_index;
for (; nodeIdx < queue_size; ++nodeIdx) {
bool expected = true;
if (nodes[nodeIdx]->free.compare_exchange_strong(expected, false)) {
node = nodes[nodeIdx];
free_index = nodeIdx + 1;
break;
}
}
if (node == nullptr) {
return false;
free_index = 1;
for (nodeIdx = 1; nodeIdx < queue_size; ++nodeIdx) {
bool expected = true;
if (nodes[nodeIdx]->free.compare_exchange_strong(expected, false)) {
node = nodes[nodeIdx];
free_index = nodeIdx + 1;
break;
}
}
if (node == nullptr) {
return false;
}
}
node->value = t;
node->next = {-1, 0};
@ -166,6 +186,8 @@ class HQueue {
std::atomic<Pointer> qhead;
std::atomic<Pointer> qtail;
std::vector<HQNode<T> *> nodes;
int32_t queue_size;
std::atomic<int32_t> free_index;
};
} // namespace mindspore
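The Enqueue change replaces the full scan from index 0 with a free_index hint: scanning starts where the last free node was found and wraps around to slot 1 once before giving up. Here is a standalone sketch of that hinted scan; NodePool is invented for illustration and is not the HQueue API.

#include <atomic>
#include <cstdint>
#include <iostream>
#include <vector>

class NodePool {
 public:
  explicit NodePool(int32_t size) : free_(size), size_(size) {
    for (auto &f : free_) f.store(true);
    free_[0].store(false);  // slot 0 reserved, like the HQueue dummy node
  }

  // Returns an acquired slot index, or -1 if the pool is full.
  int32_t Acquire() {
    // First pass: start at the hint left by the previous acquisition.
    int32_t idx = TryAcquireFrom(free_index_.load());
    if (idx >= 0) return idx;
    // Second pass: wrap around and rescan from slot 1.
    return TryAcquireFrom(1);
  }

  void Release(int32_t idx) { free_[idx].store(true); }

 private:
  int32_t TryAcquireFrom(int32_t start) {
    for (int32_t i = start; i < size_; ++i) {
      bool expected = true;
      if (free_[i].compare_exchange_strong(expected, false)) {
        free_index_.store(i + 1);  // hint for the next caller
        return i;
      }
    }
    return -1;
  }

  std::vector<std::atomic<bool>> free_;
  int32_t size_;
  std::atomic<int32_t> free_index_{1};
};

int main() {
  NodePool pool(4);
  std::cout << pool.Acquire() << ' ' << pool.Acquire() << ' '
            << pool.Acquire() << ' ' << pool.Acquire() << '\n';  // 1 2 3 -1
  pool.Release(2);
  std::cout << pool.Acquire() << '\n';  // second pass wraps and finds 2
}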

View File

@ -26,9 +26,14 @@ namespace mindspore {
{ printf("[INFO] %s|%d: " #content "\r\n", __func__, __LINE__, ##args); }
#define THREAD_ERROR(content, args...) \
{ printf("[ERROR] %s|%d: " #content "\r\n", __func__, __LINE__, ##args); }
#define THREAD_TEST_TRUE(flag) \
if (!(flag)) { \
printf("[ERROR] %s|%d: " #flag "\r\n", __func__, __LINE__); \
}
#else
#define THREAD_DEBUG(content, ...)
#define THREAD_INFO(content, ...)
#define THREAD_TEST_TRUE(flag)
#if defined(__ANDROID__)
#include <android/log.h>
#define THREAD_ERROR(content, args...) \

View File

@ -30,6 +30,13 @@ Worker::~Worker() {
if (thread_.joinable()) {
thread_.join();
}
pool_ = nullptr;
#ifdef OPERATOR_PARALLELISM
if (task_messages_ != nullptr) {
free(task_messages_);
task_messages_ = nullptr;
}
#endif
}
void Worker::CreateThread() { thread_ = std::thread(&Worker::Run, this); }
@ -93,9 +100,16 @@ void Worker::Run() {
} else {
YieldAndDeactive();
}
#ifdef OPERATOR_PARALLELISM
if (RunQueueWorkTask()) {
if (spin_count_ > 0) {
spin_count_ = 1;
}
}
#endif
if (spin_count_ > max_spin_count_) {
WaitUntilActive();
spin_count_ = 0;
spin_count_ = 1;
}
}
}
@ -115,6 +129,7 @@ bool Worker::RunLocalKernelTask() {
void Worker::YieldAndDeactive() {
// deactivate this worker only on the first entry
if (spin_count_ == 0) {
THREAD_TEST_TRUE(task_ == nullptr);
status_.store(kThreadIdle);
}
spin_count_++;
@ -135,6 +150,7 @@ void Worker::set_scale(float lhs_scale, float rhs_scale) {
void Worker::Active(Task *task, int task_id) {
{
std::lock_guard<std::mutex> _l(mutex_);
THREAD_TEST_TRUE(task_ == nullptr);
task_id_.store(task_id, std::memory_order_relaxed);
task_.store(task, std::memory_order_release);
status_ = kThreadBusy;
@ -156,6 +172,14 @@ bool Worker::available() {
return status_.compare_exchange_strong(expected, kThreadHeld);
}
bool Worker::check_task_nullptr() {
std::lock_guard<std::mutex> _l(mutex_);
if (status_ == kThreadBusy && task_ == nullptr) {
return true;
}
return false;
}
ThreadPool::~ThreadPool() {
for (auto &worker : workers_) {
delete worker;
@ -167,10 +191,23 @@ ThreadPool::~ThreadPool() {
delete affinity_;
affinity_ = nullptr;
}
#ifdef OPERATOR_PARALLELISM
#ifdef USE_HQUEUE
task_queue_.Clean();
#endif
#endif
THREAD_INFO("destruct success");
}
int ThreadPool::CreateThreads(size_t thread_num, const std::vector<int> &core_list) {
#ifdef OPERATOR_PARALLELISM
#ifdef USE_HQUEUE
if (!task_queue_.IsInit() && !task_queue_.Init(MAX_READY_TASK_NR)) {
THREAD_ERROR("Init task queue failed");
return THREAD_ERROR;
}
#endif
#endif
size_t core_num = std::thread::hardware_concurrency();
thread_num = thread_num < core_num ? thread_num : core_num;
THREAD_INFO("ThreadInfo, Num: [%zu], CoreNum: [%zu]", thread_num, core_num);
@ -180,7 +217,7 @@ int ThreadPool::CreateThreads(size_t thread_num, const std::vector<int> &core_li
}
std::lock_guard<std::mutex> _l(pool_mutex_);
for (size_t i = 0; i < thread_num; ++i) {
auto worker = new (std::nothrow) Worker();
auto worker = new (std::nothrow) Worker(this);
THREAD_ERROR_IF_NULL(worker);
worker->InitWorkerMask(core_list, workers_.size());
worker->CreateThread();
@ -206,12 +243,23 @@ int ThreadPool::ParallelLaunch(const Func &func, Content content, int task_num)
// if the task num is greater than the KernelThread num
THREAD_DEBUG("launch: %d", task_num);
Task task = {func, content};
DistributeTask(&task, task_num);
Worker *curr = CurrentWorker();
DistributeTask(&task, task_num, curr);
// synchronization
// wait until task.finished equals task_num
if (curr != nullptr) {
if (curr->RunLocalKernelTask()) {
curr->set_task_free(true);
}
}
while (task.finished != task_num) {
#ifdef OPERATOR_PARALLELISM
if (!RunQueueWorkTask() && (curr == nullptr || !curr->RunLocalKernelTask())) {
std::this_thread::yield();
}
#else
std::this_thread::yield();
#endif
}
// check the return value of task
if (task.status != THREAD_OK) {
@ -233,41 +281,68 @@ void ThreadPool::SyncRunTask(Task *task, int start_num, int task_num) const {
}
}
void ThreadPool::DistributeTask(Task *task, int task_num) const {
Worker *curr = CurrentWorker();
// if curr isn't nullptr, the caller is itself an ActorThread;
// assign (task_num - 1) tasks to the workers and run the last one on the caller
int count = 0;
int num_assigned = curr != nullptr ? task_num - 1 : task_num;
void ThreadPool::DistributeTask(Task *task, int task_num, Worker *curr) const {
int sum_frequency = 0;
std::vector<Worker *> assigned;
int num = static_cast<int>(workers_.size()) - 1;
// if curr isn't nullptr, the caller is itself an ActorThread;
// assign (task_num - 1) tasks to the workers and run the last one on the caller
#ifdef OPERATOR_PARALLELISM
int num_assigned = task_num;
bool use_curr = curr != nullptr;
int count = use_curr ? 1 : 0;
int offset = static_cast<int>(actor_thread_num_);
#else
int num_assigned = curr != nullptr ? task_num - 1 : task_num;
int count = 0;
int offset = 0;
bool use_curr = false;
if (curr != nullptr) {
use_curr = curr->get_task_free();
}
#endif
if (!occupied_actor_thread_) {
offset = static_cast<int>(actor_thread_num_);
}
for (int i = num; i >= offset && count < num_assigned; --i) {
if (workers_[i]->available()) {
assigned.push_back(workers_[i]);
sum_frequency += workers_[i]->frequency();
count++;
(void)++count;
}
}
// when there are not enough free workers,
// run the remaining tasks on the calling thread
if (curr != nullptr) {
if (use_curr) {
#ifdef OPERATOR_PARALLELISM
assigned.push_back(curr);
if (count < task_num) {
auto task_messages = curr->GetTaskMessages();
if (task_messages != nullptr) {
for (; count < task_num; ++count) {
task_messages[count].task = task;
PushTaskToQueue(&task_messages[count]);
}
}
}
#else
for (; count < task_num; ++count) {
assigned.push_back(curr);
sum_frequency += curr->frequency();
}
#endif
} else if (assigned.size() != static_cast<size_t>(task_num)) {
CalculateScales(assigned, sum_frequency);
ActiveWorkers(assigned, task, assigned.size(), curr);
SyncRunTask(task, assigned.size(), task_num);
return;
}
CalculateScales(assigned, sum_frequency);
ActiveWorkers(assigned, task, task_num, curr);
ActiveWorkers(assigned, task, assigned.size(), curr);
}
void ThreadPool::CalculateScales(const std::vector<Worker *> &assigned, int sum_frequency) const {
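The operator-parallelism path changes the accounting in DistributeTask: the caller claims one task up front (count starts at 1), each free worker takes one task, and the remainder overflows into the shared queue instead of all piling onto the caller. A small worked example of just that arithmetic; SplitTasks is a made-up helper, not MindSpore code.

#include <algorithm>
#include <iostream>

struct Split {
  int caller;   // tasks run inline by the calling thread
  int workers;  // tasks handed directly to free workers
  int queued;   // tasks pushed to the shared queue
};

Split SplitTasks(int task_num, int free_workers, bool caller_participates) {
  Split s{};
  s.caller = caller_participates ? 1 : 0;
  s.workers = std::min(free_workers, task_num - s.caller);
  s.queued = task_num - s.caller - s.workers;
  return s;
}

int main() {
  // 8 tasks, 3 free workers, caller participates: 1 inline, 3 direct, 4 queued
  Split s = SplitTasks(8, 3, true);
  std::cout << s.caller << ' ' << s.workers << ' ' << s.queued << '\n';
}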

View File

@ -17,6 +17,7 @@
#ifndef MINDSPORE_CORE_MINDRT_RUNTIME_THREADPOOL_H_
#define MINDSPORE_CORE_MINDRT_RUNTIME_THREADPOOL_H_
#include <queue>
#include <new>
#include <vector>
#include <unordered_map>
@ -35,7 +36,9 @@
#endif
#endif
#include "utils/visible.h"
#include "thread/hqueue.h"
#define USE_HQUEUE
namespace mindspore {
constexpr int kDefaultSpinCount = 300000;
constexpr int kMaxCount = 30000;
@ -43,6 +46,9 @@ constexpr int kDefaultKernelSpinCount = 3000;
constexpr int kMinSpinCount = 1;
constexpr int kDefaultFrequency = 1;
constexpr float kMaxScale = 1.;
#ifdef OPERATOR_PARALLELISM
constexpr size_t MAX_READY_TASK_NR = 4096;
#endif
enum ThreadStatus {
kThreadBusy = 0, // busy, the thread is running task
@ -62,13 +68,19 @@ typedef struct Task {
std::atomic_int finished{0};
std::atomic_int status{THREAD_OK}; // return status, RET_OK
} Task;
#ifdef OPERATOR_PARALLELISM
typedef struct TaskMessage {
Task *task;
int task_id;
} TaskMessage;
#endif
class ThreadPool;
class Worker {
public:
Worker() = default;
explicit Worker(ThreadPool *pool) : pool_(pool) {}
virtual ~Worker();
// create thread and start running at the same time
void CreateThread();
virtual void CreateThread();
// assign task and then activate thread
void Active(Task *task, int task_id);
// activate thread
@ -96,6 +108,14 @@ class Worker {
void set_mask(const cpu_set_t &mask) { mask_ = mask; }
pthread_t handle() { return thread_.native_handle(); }
#endif
#ifdef OPERATOR_PARALLELISM
inline TaskMessage *GetTaskMessages() { return task_messages_; }
inline void SetTaskMessages(TaskMessage *task_messages) { task_messages_ = task_messages; }
inline bool RunQueueWorkTask() const;
#endif
bool check_task_nullptr();
inline bool get_task_free() const { return task_free_; }
inline void set_task_free(bool flag) { task_free_ = flag; }
protected:
void SetAffinity();
@ -123,6 +143,11 @@ class Worker {
int frequency_{kDefaultFrequency};
int spin_count_{0};
int max_spin_count_{kMinSpinCount};
bool task_free_{true};
ThreadPool *pool_{nullptr};
#ifdef OPERATOR_PARALLELISM
TaskMessage *task_messages_{nullptr};
#endif
};
class MS_CORE_API ThreadPool {
@ -150,6 +175,44 @@ class MS_CORE_API ThreadPool {
void SetWorkerIdMap();
const std::unordered_map<std::thread::id, size_t> &GetWorkerIdMap() const { return worker_ids_; }
float GetServerCpuFrequence() const { return server_cpu_frequence; }
#ifdef OPERATOR_PARALLELISM
inline TaskMessage *PopTaskFromQueue() const {
#ifdef USE_HQUEUE
return task_queue_.Dequeue();
#else
std::lock_guard<std::mutex> _l(task_mutex_);
if (task_queue_.empty()) {
return nullptr;
}
auto task_message = task_queue_.front();
task_queue_.pop();
return task_message;
#endif
}
inline void PushTaskToQueue(TaskMessage *task_message) const {
if (!task_message) {
return;
}
#ifdef USE_HQUEUE
while (!task_queue_.Enqueue(task_message)) {
}
#else
std::lock_guard<std::mutex> _l(task_mutex_);
task_queue_.push(task_message);
#endif
}
inline bool RunQueueWorkTask() const {
auto task_message = PopTaskFromQueue();
if (task_message == nullptr) {
return false;
}
auto task = task_message->task;
task_message->task = nullptr;
task->status |= task->func(task->content, task_message->task_id, 0, 0);
(void)++task->finished;
return true;
}
#endif
protected:
ThreadPool() = default;
@ -160,7 +223,7 @@ class MS_CORE_API ThreadPool {
void SyncRunTask(Task *task, int start_num, int task_num) const;
void DistributeTask(Task *task, int task_num) const;
void DistributeTask(Task *task, int task_num, Worker *curr) const;
void CalculateScales(const std::vector<Worker *> &workers, int sum_frequency) const;
void ActiveWorkers(const std::vector<Worker *> &workers, Task *task, int task_num, const Worker *curr) const;
@ -176,6 +239,17 @@ class MS_CORE_API ThreadPool {
int max_spin_count_{kDefaultSpinCount};
int min_spin_count_{kMinSpinCount};
float server_cpu_frequence = -1.0f; // Unit : GHz
#ifdef OPERATOR_PARALLELISM
#ifdef USE_HQUEUE
mutable HQueue<TaskMessage> task_queue_;
#else
mutable std::mutex task_mutex_;
mutable std::queue<TaskMessage *> task_queue_;
#endif
#endif
};
#ifdef OPERATOR_PARALLELISM
inline bool Worker::RunQueueWorkTask() const { return pool_->RunQueueWorkTask(); }
#endif
} // namespace mindspore
#endif // MINDSPORE_CORE_MINDRT_RUNTIME_THREADPOOL_H_
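thread_pool.h now carries the shared TaskMessage queue: producers push one message per task id, and any thread, including the caller inside ParallelLaunch, can pop and run them, incrementing task->finished exactly once per message. Below is a simplified, lock-based sketch of that flow; the names are illustrative, and the real pool prefers the lock-free HQueue over a mutex-guarded std::queue.

#include <atomic>
#include <functional>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

struct Task {
  std::function<void(int)> func;
  std::atomic<int> finished{0};
};
struct TaskMessage {
  Task *task;
  int task_id;
};

std::mutex queue_mutex;
std::queue<TaskMessage *> task_queue;

void PushTask(TaskMessage *msg) {
  std::lock_guard<std::mutex> lock(queue_mutex);
  task_queue.push(msg);
}

bool RunQueuedTask() {
  TaskMessage *msg = nullptr;
  {
    std::lock_guard<std::mutex> lock(queue_mutex);
    if (task_queue.empty()) return false;
    msg = task_queue.front();
    task_queue.pop();
  }
  msg->task->func(msg->task_id);
  ++msg->task->finished;  // counted exactly once per message
  return true;
}

int main() {
  constexpr int kTaskNum = 8;
  Task task;
  task.func = [](int id) { std::cout << "task " << id << '\n'; };
  std::vector<TaskMessage> messages(kTaskNum);
  for (int i = 0; i < kTaskNum; ++i) {
    messages[i] = {&task, i};
    PushTask(&messages[i]);
  }
  // One helper thread drains the queue alongside the caller.
  std::thread helper([] { while (RunQueuedTask()) {} });
  // The caller also drains instead of only yielding, as in ParallelLaunch.
  while (task.finished.load() != kTaskNum) {
    if (!RunQueuedTask()) std::this_thread::yield();
  }
  helper.join();
}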

View File

@ -31,9 +31,6 @@
namespace mindspore::lite {
namespace {
#ifdef ENABLE_MINDRT
constexpr int kDefaultParallelNum = 2;
#endif
const constexpr int kMaxLiteContextDeviceNums = 2;
const constexpr int kMaxInnerContextDeviceNums = 3;
} // namespace
@ -111,11 +108,7 @@ void InnerContext::SetContextDevice(const Context *context) {
return;
}
int InnerContext::Init() {
if (RET_OK != this->IsValid()) {
MS_LOG(ERROR) << "Context is not valid";
return RET_NOT_SUPPORT;
}
int InnerContext::CreateThreadPool() {
if (this->thread_pool_ == nullptr) {
BindMode bind_mode = Power_NoBind;
if (this->IsCpuEnabled()) {
@ -123,7 +116,11 @@ int InnerContext::Init() {
}
#ifdef ENABLE_MINDRT
#ifdef OPERATOR_PARALLELISM
int actor_parallel_thread = this->enable_parallel_ ? (this->thread_num_ > 2 ? (this->thread_num_ / 2) : 1) : 1;
#else
int actor_parallel_thread = this->enable_parallel_ ? kDefaultParallelNum : 1;
#endif
if (this->affinity_core_list_.empty()) {
thread_pool_ = ActorThreadPool::CreateThreadPool(actor_parallel_thread, this->thread_num_, bind_mode);
MS_CHECK_TRUE_MSG(thread_pool_ != nullptr, RET_NULL_PTR, "Create ActorThreadPool failed");
@ -137,6 +134,18 @@ int InnerContext::Init() {
thread_pool_->SetCpuAffinity(static_cast<mindspore::BindMode>(bind_mode));
#endif
}
return RET_OK;
}
int InnerContext::Init() {
if (this->IsValid() != RET_OK) {
MS_LOG(ERROR) << "Context is not valid";
return RET_NOT_SUPPORT;
}
if (CreateThreadPool() != RET_OK) {
MS_LOG(ERROR) << "CreateThreadPool failed.";
return RET_ERROR;
}
if (this->allocator == nullptr) {
#ifdef SERVER_INFERENCE
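With OPERATOR_PARALLELISM, the number of actor threads is derived from thread_num_ rather than the fixed kDefaultParallelNum. A small worked example of the formula; ActorParallelThread is a hypothetical stand-in for the inline ternary above.

#include <iostream>

int ActorParallelThread(int thread_num, bool enable_parallel) {
  if (!enable_parallel) return 1;
  return thread_num > 2 ? thread_num / 2 : 1;
}

int main() {
  for (int n : {1, 2, 3, 4, 8}) {
    std::cout << "thread_num=" << n
              << " -> actor threads=" << ActorParallelThread(n, true) << '\n';
  }
  // prints 1, 1, 1, 2, 4
}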

View File

@ -32,6 +32,11 @@
#endif
namespace mindspore::lite {
#ifdef ENABLE_MINDRT
#ifndef OPERATOR_PARALLELISM
constexpr int kDefaultParallelNum = 2;
#endif
#endif
struct InnerContext : public Context {
public:
InnerContext() { InitDeviceFp16(); }
@ -108,6 +113,8 @@ struct InnerContext : public Context {
void InitDeviceFp16();
int CreateThreadPool();
bool device_and_pkg_support_fp16_ = false;
#ifdef SERVER_INFERENCE

View File

@ -266,11 +266,19 @@ int Scheduler::SchedulePreProcess() {
return *is_infershape_;
}
if (context_->enable_parallel_ && *is_infershape_ != RET_INFER_INVALID) {
if (context_->enable_parallel_) {
#ifndef AUTO_PARALLEL_CLIP
#ifdef OPERATOR_PARALLELISM
auto search_sub_graph =
SearchSubGraph(context_, src_model_, src_tensors_, &op_parameters_, &graph_output_node_indexes_);
search_sub_graph.SubGraphSplit();
search_sub_graph.SubGraphSplitByOperator();
#else
if (*is_infershape_ != RET_INFER_INVALID) {
auto search_sub_graph =
SearchSubGraph(context_, src_model_, src_tensors_, &op_parameters_, &graph_output_node_indexes_);
search_sub_graph.SubGraphSplit();
}
#endif
#else
MS_LOG(ERROR) << unsupport_auto_parallel_log;
return RET_NOT_SUPPORT;

View File

@ -20,6 +20,7 @@
#include <algorithm>
#include <iterator>
#include <vector>
#include <queue>
#include "src/tensor.h"
#include "schema/ops_generated.h"
#include "schema/model_generated.h"
@ -216,9 +217,11 @@ const schema::Primitive *SearchSubGraph::CreatePartialPrimitive(int64_t subgraph
}
void SearchSubGraph::ConvertSubGraphToModel(std::vector<Subgraph> *sub_graphs) {
#ifndef OPERATOR_PARALLELISM
if (sub_graphs->size() != kDefaultSubGraphSize) {
return;
}
#endif
Model::SubGraph *main_graphs = model_->sub_graphs_.front();
for (Subgraph &subgraph : *sub_graphs) {
@ -1069,4 +1072,96 @@ void SearchSubGraph::SubGraphSplit() {
}
return;
}
#ifdef OPERATOR_PARALLELISM
void SearchSubGraph::InsertNodeBegin(uint32_t index, Subgraph *subgraph, std::vector<size_t> *outputs) {
size_t last_index = index;
while (true) {
Model::Node *node = node_list_.at(index);
if (node == nullptr) {
subgraph->heads_.push_back(last_index);
return;
}
std::vector<uint32_t> input = node->input_indices_;
RemoveConstNode(&input);
/* all node inputs are graph inputs: this chain starts here */
if (!input.empty() && std::all_of(input.begin(), input.end(), [this](uint32_t in) {
return tensors_[in].type_ == INPUT;
})) {
subgraph->heads_.push_back(last_index);
return;
}
/* split in graph */
if (IsNodeSubGraphHead(index, subgraph->nodes_)) {
if (subgraph->nodes_.empty()) {
subgraph->heads_.push_back(index);
subgraph->nodes_.insert(subgraph->nodes_.begin(), index);
node_list_.at(index) = nullptr;
for (uint32_t in : input) {
auto next_nodes = tensors_[in].out_nodes_;
std::copy(next_nodes.begin(), next_nodes.end(), std::back_inserter(*outputs));
}
return;
}
subgraph->heads_.push_back(last_index);
outputs->push_back(index);
return;
}
for (uint32_t in : input) {
auto next_nodes = tensors_[in].out_nodes_;
std::copy(next_nodes.begin(), next_nodes.end(), std::back_inserter(*outputs));
}
subgraph->nodes_.insert(subgraph->nodes_.begin(), index);
node_list_.at(index) = nullptr;
if (outputs->size() == 1) {
last_index = index;
index = outputs->at(0);
outputs->clear();
} else {
subgraph->heads_.push_back(index);
return;
}
}
return;
}
void SearchSubGraph::SubGraphSplitByOperator() {
if (!ValidInParallel()) {
return;
}
sub_graphs_.clear();
node_list_ = model_->all_nodes_;
std::queue<size_t> outputs{};
for (auto out : *output_nodes_) {
outputs.push(out);
}
std::vector<size_t> outputs_vec{};
while (!outputs.empty()) {
auto out = outputs.front();
outputs.pop();
Subgraph subgraph;
subgraph.ends_.push_back(out);
subgraph.device_ = DT_CPU;
subgraph.thread_ = context_->thread_num_ > 2 ? (context_->thread_num_ / 2) : 1;
InsertNodeBegin(static_cast<uint32_t>(out), &subgraph, &outputs_vec);
for (auto new_out : outputs_vec) {
outputs.push(new_out);
}
outputs_vec.clear();
if (!subgraph.nodes_.empty()) {
sub_graphs_.push_back(std::move(subgraph));
}
}
ConvertSubGraphToModel(&sub_graphs_);
}
#endif
} // namespace mindspore::lite
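SubGraphSplitByOperator grows one subgraph per single-producer chain: it walks from each graph output toward the inputs, cuts at fan-in points, and queues the producers there as new chain starts. The following toy traversal illustrates the idea on a 5-node graph; it uses plain integers and vectors, not the MindSpore model structures.

#include <iostream>
#include <queue>
#include <vector>

int main() {
  // producers[i] = nodes whose outputs feed node i; empty = graph input node.
  std::vector<std::vector<int>> producers = {
      {},      // 0: graph input
      {0},     // 1
      {1},     // 2
      {1},     // 3  (node 1 fans out to 2 and 3)
      {2, 3},  // 4: graph output, fan-in
  };
  std::vector<bool> claimed(producers.size(), false);
  std::queue<int> pending;
  pending.push(4);  // start from the graph output

  while (!pending.empty()) {
    int index = pending.front();
    pending.pop();
    if (claimed[index]) continue;
    std::vector<int> subgraph;
    // walk a single-producer chain upward, as InsertNodeBegin does
    while (true) {
      claimed[index] = true;
      subgraph.insert(subgraph.begin(), index);
      const auto &in = producers[index];
      if (in.size() != 1 || claimed[in[0]]) {
        for (int p : in)
          if (!claimed[p]) pending.push(p);  // fan-in: split here
        break;
      }
      index = in[0];
    }
    std::cout << "subgraph:";
    for (int n : subgraph) std::cout << ' ' << n;
    std::cout << '\n';
  }
  // prints: subgraph: 4 / subgraph: 0 1 2 / subgraph: 3
}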

View File

@ -88,6 +88,10 @@ class SearchSubGraph {
public:
void SubGraphSplit();
#ifdef OPERATOR_PARALLELISM
void SubGraphSplitByOperator();
void InsertNodeBegin(uint32_t index, Subgraph *subgraph, std::vector<size_t> *outputs);
#endif
private: /* split by output */
void SubGraphSplitByOutput();

View File

@ -35,6 +35,10 @@ file(GLOB_RECURSE TEST_UT_SRC
${TEST_DIR}/ut/src/api/tensor_c_test.cc
)
if(MSLITE_ENABLE_SERVER_INFERENCE)
list(REMOVE_ITEM TEST_UT_SRC ${TEST_DIR}/st/mindrt_parallel_runtime_test.cc)
endif()
if(MSLITE_ENABLE_RUNTIME_CONVERT)
list(APPEND TEST_UT_SRC ${TEST_DIR}/ut/src/runtime/runtime_convert_tests.cc)
endif()
@ -115,6 +119,12 @@ if(MSLITE_ENABLE_CONVERTER)
${TEST_DIR}/ut/src/dynamic_library_loader_test.cc
${TEST_DIR}/ut/tools/optimizer/fusion/*.cc
)
if(MSLITE_ENABLE_SERVER_INFERENCE)
list(REMOVE_ITEM TEST_CONVERTER_UT_SRC ${TEST_DIR}/st/mindrt_parallel_test.cc)
list(REMOVE_ITEM TEST_UT_SRC ${TEST_DIR}/st/benchmark_test.cc)
list(REMOVE_ITEM TEST_CONVERTER_UT_SRC ${TEST_DIR}/st/graph_test.cc)
list(REMOVE_ITEM TEST_CONVERTER_UT_SRC ${TEST_DIR}/st/sub_graph_test.cc)
endif()
list(APPEND TEST_UT_SRC ${TEST_CONVERTER_UT_SRC})
set(TEST_LITE_SRC