diff --git a/.jenkins/check/config/whitelizard.txt b/.jenkins/check/config/whitelizard.txt index f976046a669..acefbeac65a 100644 --- a/.jenkins/check/config/whitelizard.txt +++ b/.jenkins/check/config/whitelizard.txt @@ -3,6 +3,7 @@ # Scene2: # file_path:function_name1, function_name2 # +mindspore/mindspore/core/mindrt/src/thread/actor_threadpool.cc:mindspore::ActorWorker::RunWithSpin mindspore/mindspore/lite/src/ops/primitive_c.cc:mindspore::lite::PrimitiveC::Create mindspore/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc:mindspore::dataset::CsvOp::CsvParser::InitCsvParser mindspore/mindspore/lite/tools/converter/graphdef_transform.cc:mindspore::lite::GraphDefTransform::Transform diff --git a/mindspore/core/mindrt/src/thread/actor_threadpool.cc b/mindspore/core/mindrt/src/thread/actor_threadpool.cc index d91bcbf86ed..bfd2850e4bb 100644 --- a/mindspore/core/mindrt/src/thread/actor_threadpool.cc +++ b/mindspore/core/mindrt/src/thread/actor_threadpool.cc @@ -22,11 +22,7 @@ namespace mindspore { constexpr size_t MAX_READY_ACTOR_NR = 4096; -void ActorWorker::CreateThread(ActorThreadPool *pool) { - THREAD_RETURN_IF_NULL(pool); - pool_ = pool; - thread_ = std::thread(&ActorWorker::RunWithSpin, this); -} +void ActorWorker::CreateThread() { thread_ = std::thread(&ActorWorker::RunWithSpin, this); } void ActorWorker::RunWithSpin() { SetAffinity(); @@ -41,24 +37,43 @@ void ActorWorker::RunWithSpin() { #endif while (alive_) { // only run either local KernelTask or PoolQueue ActorTask - if (RunLocalKernelTask() || RunQueueActorTask()) { + if (RunLocalKernelTask()) { spin_count_ = 0; } else { YieldAndDeactive(); } +#ifdef OPERATOR_PARALLELISM + if (RunQueueActorTask() || RunQueueWorkTask()) { +#else + if (RunQueueActorTask()) { +#endif + if (spin_count_ > 0) { + spin_count_ = 1; + } + } if (spin_count_ > max_spin_count_) { WaitUntilActive(); - spin_count_ = 0; + spin_count_ = 1; } } } bool ActorWorker::RunQueueActorTask() { - THREAD_ERROR_IF_NULL(pool_); - 
auto actor = pool_->PopActorFromQueue(); + if (pool_ == nullptr) { + return false; + } + auto actor = reinterpret_cast(pool_)->PopActorFromQueue(); if (actor == nullptr) { return false; } +#ifndef OPERATOR_PARALLELISM + if (available() || check_task_nullptr()) { + status_ = kThreadBusy; + set_task_free(true); + } else { + set_task_free(false); + } +#endif actor->Run(); return true; } @@ -149,6 +164,12 @@ int ActorThreadPool::CreateThreads(size_t actor_thread_num, size_t all_thread_nu THREAD_ERROR("init actor queue failed."); return THREAD_ERROR; } +#ifdef OPERATOR_PARALLELISM + if (task_queue_.Init(MAX_READY_TASK_NR) != true) { + THREAD_ERROR("Init task queue failed"); + return THREAD_ERROR; + } +#endif #endif if (affinity_ != nullptr) { affinity_->SetCoreId(core_list); @@ -162,10 +183,22 @@ int ActorThreadPool::CreateThreads(size_t actor_thread_num, size_t all_thread_nu } for (size_t i = 0; i < actor_thread_num_; ++i) { std::lock_guard _l(pool_mutex_); - auto worker = new (std::nothrow) ActorWorker(); + auto worker = new (std::nothrow) ActorWorker(this); THREAD_ERROR_IF_NULL(worker); +#ifdef OPERATOR_PARALLELISM + auto task_messages = reinterpret_cast(malloc(sizeof(TaskMessage) * all_thread_num)); + if (task_messages == nullptr) { + delete worker; + THREAD_ERROR("malloc TaskMessages failed."); + return THREAD_ERROR; + } + for (size_t j = 0; j < all_thread_num; j++) { + task_messages[j].task_id = j; + } + worker->SetTaskMessages(task_messages); +#endif worker->InitWorkerMask(core_list, workers_.size()); - worker->CreateThread(this); + worker->CreateThread(); workers_.push_back(worker); THREAD_INFO("create actor thread[%zu]", i); } diff --git a/mindspore/core/mindrt/src/thread/actor_threadpool.h b/mindspore/core/mindrt/src/thread/actor_threadpool.h index 12a73816442..318c7d670ef 100644 --- a/mindspore/core/mindrt/src/thread/actor_threadpool.h +++ b/mindspore/core/mindrt/src/thread/actor_threadpool.h @@ -31,14 +31,13 @@ namespace mindspore { class ActorThreadPool; 
class ActorWorker : public Worker { public: - void CreateThread(ActorThreadPool *pool); + explicit ActorWorker(ThreadPool *pool) : Worker(pool) {} + void CreateThread() override; bool ActorActive(); private: void RunWithSpin(); bool RunQueueActorTask(); - - ActorThreadPool *pool_{nullptr}; }; class ActorThreadPool : public ThreadPool { @@ -58,7 +57,6 @@ class ActorThreadPool : public ThreadPool { private: ActorThreadPool() {} int CreateThreads(size_t actor_thread_num, size_t all_thread_num, const std::vector &core_list); - size_t actor_thread_num_{0}; std::mutex actor_mutex_; std::condition_variable actor_cond_; diff --git a/mindspore/core/mindrt/src/thread/hqueue.h b/mindspore/core/mindrt/src/thread/hqueue.h index 744bf99643d..03e4623200b 100644 --- a/mindspore/core/mindrt/src/thread/hqueue.h +++ b/mindspore/core/mindrt/src/thread/hqueue.h @@ -46,7 +46,12 @@ class HQueue { HQueue() {} virtual ~HQueue() {} + bool IsInit() { return nodes.size() != 0; } + bool Init(int32_t sz) { + if (IsInit() || sz <= 0) { + return false; + } for (int32_t i = 0; i < sz; i++) { auto node = new HQNode(); if (node == nullptr) { @@ -63,6 +68,8 @@ class HQueue { qhead = {0, 0}; qtail = {0, 0}; nodes[0]->free = false; + queue_size = sz; + free_index = 1; return true; } @@ -75,17 +82,30 @@ class HQueue { bool Enqueue(T *t) { HQNode *node = nullptr; - int32_t nodeIdx; - for (nodeIdx = 0; nodeIdx < static_cast(nodes.size()); nodeIdx++) { + int32_t nodeIdx = free_index; + for (; nodeIdx < queue_size; ++nodeIdx) { bool expected = true; if (nodes[nodeIdx]->free.compare_exchange_strong(expected, false)) { node = nodes[nodeIdx]; + free_index = nodeIdx + 1; break; } } if (node == nullptr) { - return false; + free_index = 1; + for (nodeIdx = 1; nodeIdx < queue_size; ++nodeIdx) { + bool expected = true; + if (nodes[nodeIdx]->free.compare_exchange_strong(expected, false)) { + node = nodes[nodeIdx]; + free_index = nodeIdx + 1; + break; + } + } + if (node == nullptr) { + return false; + } } + 
node->value = t; node->next = {-1, 0}; @@ -166,6 +186,8 @@ class HQueue { std::atomic qhead; std::atomic qtail; std::vector *> nodes; + int32_t queue_size; + std::atomic free_index; }; } // namespace mindspore diff --git a/mindspore/core/mindrt/src/thread/threadlog.h b/mindspore/core/mindrt/src/thread/threadlog.h index 8594d852daa..2b95cc3be17 100644 --- a/mindspore/core/mindrt/src/thread/threadlog.h +++ b/mindspore/core/mindrt/src/thread/threadlog.h @@ -26,9 +26,14 @@ namespace mindspore { { printf("[INFO] %s|%d: " #content "\r\n", __func__, __LINE__, ##args); } #define THREAD_ERROR(content, args...) \ { printf("[ERROR] %s|%d: " #content "\r\n", __func__, __LINE__, ##args); } +#define THREAD_TEST_TRUE(flag) \ + if (flag) { \ + printf("[ERROR] %s|%d: " #flag "\r\n", __func__, __LINE__); \ + } #else #define THREAD_DEBUG(content, ...) #define THREAD_INFO(content, ...) +#define THREAD_TEST_TRUE(flag) #if defined(__ANDROID__) #include #define THREAD_ERROR(content, args...) \ diff --git a/mindspore/core/mindrt/src/thread/threadpool.cc b/mindspore/core/mindrt/src/thread/threadpool.cc index a8c7610031c..58b8bc19f2a 100644 --- a/mindspore/core/mindrt/src/thread/threadpool.cc +++ b/mindspore/core/mindrt/src/thread/threadpool.cc @@ -30,6 +30,13 @@ Worker::~Worker() { if (thread_.joinable()) { thread_.join(); } + pool_ = nullptr; +#ifdef OPERATOR_PARALLELISM + if (task_messages_ != nullptr) { + free(task_messages_); + task_messages_ = nullptr; + } +#endif } void Worker::CreateThread() { thread_ = std::thread(&Worker::Run, this); } @@ -93,9 +100,16 @@ void Worker::Run() { } else { YieldAndDeactive(); } +#ifdef OPERATOR_PARALLELISM + if (RunQueueWorkTask()) { + if (spin_count_ > 0) { + spin_count_ = 1; + } + } +#endif if (spin_count_ > max_spin_count_) { WaitUntilActive(); - spin_count_ = 0; + spin_count_ = 1; } } } @@ -115,6 +129,7 @@ bool Worker::RunLocalKernelTask() { void Worker::YieldAndDeactive() { // deactivate this worker only on the first entry if (spin_count_ == 0) { + 
THREAD_TEST_TRUE(task_ == nullptr); status_.store(kThreadIdle); } spin_count_++; @@ -135,6 +150,7 @@ void Worker::set_scale(float lhs_scale, float rhs_scale) { void Worker::Active(Task *task, int task_id) { { std::lock_guard _l(mutex_); + THREAD_TEST_TRUE(task_ == nullptr); task_id_.store(task_id, std::memory_order_relaxed); task_.store(task, std::memory_order_release); status_ = kThreadBusy; @@ -156,6 +172,14 @@ bool Worker::available() { return status_.compare_exchange_strong(expected, kThreadHeld); } +bool Worker::check_task_nullptr() { + std::lock_guard _l(mutex_); + if (status_ == kThreadBusy && task_ == nullptr) { + return true; + } + return false; +} + ThreadPool::~ThreadPool() { for (auto &worker : workers_) { delete worker; @@ -167,10 +191,23 @@ ThreadPool::~ThreadPool() { delete affinity_; affinity_ = nullptr; } +#ifdef OPERATOR_PARALLELISM +#ifdef USE_HQUEUE + task_queue_.Clean(); +#endif +#endif THREAD_INFO("destruct success"); } int ThreadPool::CreateThreads(size_t thread_num, const std::vector &core_list) { +#ifdef OPERATOR_PARALLELISM +#ifdef USE_HQUEUE + if ((!task_queue_.IsInit()) && task_queue_.Init(MAX_READY_TASK_NR) != true) { + THREAD_ERROR("Init task queue failed"); + return THREAD_ERROR; + } +#endif +#endif size_t core_num = std::thread::hardware_concurrency(); thread_num = thread_num < core_num ? 
thread_num : core_num; THREAD_INFO("ThreadInfo, Num: [%zu], CoreNum: [%zu]", thread_num, core_num); @@ -180,7 +217,7 @@ int ThreadPool::CreateThreads(size_t thread_num, const std::vector &core_li } std::lock_guard _l(pool_mutex_); for (size_t i = 0; i < thread_num; ++i) { - auto worker = new (std::nothrow) Worker(); + auto worker = new (std::nothrow) Worker(this); THREAD_ERROR_IF_NULL(worker); worker->InitWorkerMask(core_list, workers_.size()); worker->CreateThread(); @@ -206,12 +243,23 @@ int ThreadPool::ParallelLaunch(const Func &func, Content content, int task_num) // if the task num is greater than the KernelThread num THREAD_DEBUG("launch: %d", task_num); Task task = {func, content}; - - DistributeTask(&task, task_num); + Worker *curr = CurrentWorker(); + DistributeTask(&task, task_num, curr); // synchronization // wait until the finished is equal to task_num + if (curr != nullptr) { + if (curr->RunLocalKernelTask()) { + curr->set_task_free(true); + } + } while (task.finished != task_num) { +#ifdef OPERATOR_PARALLELISM + if (RunQueueWorkTask() == false && (curr == nullptr || curr->RunLocalKernelTask() == false)) { + std::this_thread::yield(); + } +#else std::this_thread::yield(); +#endif } // check the return value of task if (task.status != THREAD_OK) { @@ -233,41 +281,68 @@ void ThreadPool::SyncRunTask(Task *task, int start_num, int task_num) const { } } -void ThreadPool::DistributeTask(Task *task, int task_num) const { - Worker *curr = CurrentWorker(); - // if the current thread isn't nullptr, that is the curr is a ActorThread, - // then assign (task_num - 1) tasks to workers, and run the last one by itself - int count = 0; - int num_assigned = curr != nullptr ? 
task_num - 1 : task_num; +void ThreadPool::DistributeTask(Task *task, int task_num, Worker *curr) const { int sum_frequency = 0; std::vector assigned; int num = static_cast(workers_.size()) - 1; + // if the current thread isn't nullptr, that is the curr is a ActorThread, + // then assign (task_num - 1) tasks to workers, and run the last one by itself +#ifdef OPERATOR_PARALLELISM + int num_assigned = task_num; + bool use_curr = curr != nullptr; + int count = use_curr ? 1 : 0; + int offset = static_cast(actor_thread_num_); +#else + int num_assigned = curr != nullptr ? task_num - 1 : task_num; + int count = 0; int offset = 0; + bool use_curr = false; + + if (curr != nullptr) { + use_curr = curr->get_task_free(); + } +#endif if (!occupied_actor_thread_) { offset = static_cast(actor_thread_num_); } + for (int i = num; i >= offset && count < num_assigned; --i) { if (workers_[i]->available()) { assigned.push_back(workers_[i]); sum_frequency += workers_[i]->frequency(); - count++; + (void)++count; } } + // when there are not enough free threads, // distribute other tasks to the master thread - if (curr != nullptr) { + if (use_curr) { +#ifdef OPERATOR_PARALLELISM + assigned.push_back(curr); + if (count < task_num) { + auto task_messeages = curr->GetTaskMessages(); + if (task_messeages != nullptr) { + for (; count < task_num; ++count) { + task_messeages[count].task = task; + PushTaskToQueue(&task_messeages[count]); + } + } + } +#else for (; count < task_num; ++count) { assigned.push_back(curr); sum_frequency += curr->frequency(); } +#endif } else if (assigned.size() != static_cast(task_num)) { CalculateScales(assigned, sum_frequency); ActiveWorkers(assigned, task, assigned.size(), curr); SyncRunTask(task, assigned.size(), task_num); return; } + CalculateScales(assigned, sum_frequency); - ActiveWorkers(assigned, task, task_num, curr); + ActiveWorkers(assigned, task, assigned.size(), curr); } void ThreadPool::CalculateScales(const std::vector &assigned, int sum_frequency) 
const { diff --git a/mindspore/core/mindrt/src/thread/threadpool.h b/mindspore/core/mindrt/src/thread/threadpool.h index f68cc039883..7b1df02cb39 100644 --- a/mindspore/core/mindrt/src/thread/threadpool.h +++ b/mindspore/core/mindrt/src/thread/threadpool.h @@ -17,6 +17,7 @@ #ifndef MINDSPORE_CORE_MINDRT_RUNTIME_THREADPOOL_H_ #define MINDSPORE_CORE_MINDRT_RUNTIME_THREADPOOL_H_ +#include #include #include #include @@ -35,7 +36,9 @@ #endif #endif #include "utils/visible.h" +#include "thread/hqueue.h" +#define USE_HQUEUE namespace mindspore { constexpr int kDefaultSpinCount = 300000; constexpr int kMaxCount = 30000; @@ -43,6 +46,9 @@ constexpr int kDefaultKernelSpinCount = 3000; constexpr int kMinSpinCount = 1; constexpr int kDefaultFrequency = 1; constexpr float kMaxScale = 1.; +#ifdef OPERATOR_PARALLELISM +constexpr size_t MAX_READY_TASK_NR = 4096; +#endif enum ThreadStatus { kThreadBusy = 0, // busy, the thread is running task @@ -62,13 +68,19 @@ typedef struct Task { std::atomic_int finished{0}; std::atomic_int status{THREAD_OK}; // return status, RET_OK } Task; - +#ifdef OPERATOR_PARALLELISM +typedef struct TaskMessage { + Task *task; + int task_id; +} TaskMessage; +#endif +class ThreadPool; class Worker { public: - Worker() = default; + explicit Worker(ThreadPool *pool) : pool_(pool) {} virtual ~Worker(); // create thread and start running at the same time - void CreateThread(); + virtual void CreateThread(); // assign task and then activate thread void Active(Task *task, int task_id); // activate thread @@ -96,6 +108,14 @@ class Worker { void set_mask(const cpu_set_t &mask) { mask_ = mask; } pthread_t handle() { return thread_.native_handle(); } #endif +#ifdef OPERATOR_PARALLELISM + inline TaskMessage *GetTaskMessages() { return task_messages_; } + inline void SetTaskMessages(TaskMessage *task_messages) { task_messages_ = task_messages; } + inline bool RunQueueWorkTask() const; +#endif + bool check_task_nullptr(); + inline bool get_task_free() const { return 
task_free_; } + inline void set_task_free(bool flag) { task_free_ = flag; } protected: void SetAffinity(); @@ -123,6 +143,11 @@ class Worker { int frequency_{kDefaultFrequency}; int spin_count_{0}; int max_spin_count_{kMinSpinCount}; + bool task_free_{true}; + ThreadPool *pool_{nullptr}; +#ifdef OPERATOR_PARALLELISM + TaskMessage *task_messages_{nullptr}; +#endif }; class MS_CORE_API ThreadPool { @@ -150,6 +175,44 @@ class MS_CORE_API ThreadPool { void SetWorkerIdMap(); const std::unordered_map &GetWorkerIdMap() const { return worker_ids_; } float GetServerCpuFrequence() const { return server_cpu_frequence; } +#ifdef OPERATOR_PARALLELISM + inline TaskMessage *PopTaskFromQueue() const { +#ifdef USE_HQUEUE + return task_queue_.Dequeue(); +#else + std::lock_guard _l(task_mutex_); + if (task_queue_.empty()) { + return nullptr; + } + auto task_message = task_queue_.front(); + task_queue_.pop(); + return task_message; +#endif + } + inline void PushTaskToQueue(TaskMessage *task_message) const { + if (!task_message) { + return; + } +#ifdef USE_HQUEUE + while (!task_queue_.Enqueue(task_message)) { + } +#else + std::lock_guard _l(task_mutex_); + task_queue_.push(task_message); +#endif + } + inline bool RunQueueWorkTask() const { + auto task_message = PopTaskFromQueue(); + if (task_message == nullptr) { + return false; + } + auto task = task_message->task; + task_message->task = nullptr; + task->status |= task->func(task->content, task_message->task_id, 0, 0); + (void)++task->finished; + return true; + } +#endif protected: ThreadPool() = default; @@ -160,7 +223,7 @@ class MS_CORE_API ThreadPool { void SyncRunTask(Task *task, int start_num, int task_num) const; - void DistributeTask(Task *task, int task_num) const; + void DistributeTask(Task *task, int task_num, Worker *curr) const; void CalculateScales(const std::vector &workers, int sum_frequency) const; void ActiveWorkers(const std::vector &workers, Task *task, int task_num, const Worker *curr) const; @@ -176,6 +239,17 @@ 
class MS_CORE_API ThreadPool { int max_spin_count_{kDefaultSpinCount}; int min_spin_count_{kMinSpinCount}; float server_cpu_frequence = -1.0f; // Unit : GHz +#ifdef OPERATOR_PARALLELISM +#ifdef USE_HQUEUE + mutable HQueue task_queue_; +#else + mutable std::mutex task_mutex_; + mutable std::queue task_queue_; +#endif +#endif }; +#ifdef OPERATOR_PARALLELISM +inline bool Worker::RunQueueWorkTask() const { return pool_->RunQueueWorkTask(); } +#endif } // namespace mindspore #endif // MINDSPORE_CORE_MINDRT_RUNTIME_THREADPOOL_H_ diff --git a/mindspore/lite/src/inner_context.cc b/mindspore/lite/src/inner_context.cc index bddc03bf874..9562503d112 100644 --- a/mindspore/lite/src/inner_context.cc +++ b/mindspore/lite/src/inner_context.cc @@ -31,9 +31,6 @@ namespace mindspore::lite { namespace { -#ifdef ENABLE_MINDRT -constexpr int kDefaultParallelNum = 2; -#endif const constexpr int kMaxLiteContextDeviceNums = 2; const constexpr int kMaxInnerContextDeviceNums = 3; } // namespace @@ -111,11 +108,7 @@ void InnerContext::SetContextDevice(const Context *context) { return; } -int InnerContext::Init() { - if (RET_OK != this->IsValid()) { - MS_LOG(ERROR) << "Context is not valid"; - return RET_NOT_SUPPORT; - } +int InnerContext::CreateThreadPool() { if (this->thread_pool_ == nullptr) { BindMode bind_mode = Power_NoBind; if (this->IsCpuEnabled()) { @@ -123,7 +116,11 @@ int InnerContext::Init() { } #ifdef ENABLE_MINDRT +#ifdef OPERATOR_PARALLELISM + int actor_parallel_thread = this->enable_parallel_ ? (this->thread_num_ > 2 ? (this->thread_num_ / 2) : 1) : 1; +#else int actor_parallel_thread = this->enable_parallel_ ? 
kDefaultParallelNum : 1; +#endif if (this->affinity_core_list_.empty()) { thread_pool_ = ActorThreadPool::CreateThreadPool(actor_parallel_thread, this->thread_num_, bind_mode); MS_CHECK_TRUE_MSG(thread_pool_ != nullptr, RET_NULL_PTR, "Create Allocator failed"); @@ -137,6 +134,18 @@ int InnerContext::Init() { thread_pool_->SetCpuAffinity(static_cast(bind_mode)); #endif } + return RET_OK; +} +int InnerContext::Init() { + if (this->IsValid() != RET_OK) { + MS_LOG(ERROR) << "Context is not valid"; + return RET_NOT_SUPPORT; + } + + if (CreateThreadPool()) { + MS_LOG(ERROR) << "CreateThreadPool failed."; + return RET_ERROR; + } if (this->allocator == nullptr) { #ifdef SERVER_INFERENCE diff --git a/mindspore/lite/src/inner_context.h b/mindspore/lite/src/inner_context.h index 11c697db7d0..3d08c864066 100644 --- a/mindspore/lite/src/inner_context.h +++ b/mindspore/lite/src/inner_context.h @@ -32,6 +32,11 @@ #endif namespace mindspore::lite { +#ifdef ENABLE_MINDRT +#ifndef OPERATOR_PARALLELISM +constexpr int kDefaultParallelNum = 2; +#endif +#endif struct InnerContext : public Context { public: InnerContext() { InitDeviceFp16(); } @@ -108,6 +113,8 @@ struct InnerContext : public Context { void InitDeviceFp16(); + int CreateThreadPool(); + bool device_and_pkg_support_fp16_ = false; #ifdef SERVER_INFERENCE diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc index c3ba959dd7d..a917301a134 100644 --- a/mindspore/lite/src/scheduler.cc +++ b/mindspore/lite/src/scheduler.cc @@ -266,11 +266,19 @@ int Scheduler::SchedulePreProcess() { return *is_infershape_; } - if (context_->enable_parallel_ && *is_infershape_ != RET_INFER_INVALID) { + if (context_->enable_parallel_) { #ifndef AUTO_PARALLEL_CLIP +#ifdef OPERATOR_PARALLELISM auto search_sub_graph = SearchSubGraph(context_, src_model_, src_tensors_, &op_parameters_, &graph_output_node_indexes_); - search_sub_graph.SubGraphSplit(); + search_sub_graph.SubGraphSplitByOperator(); +#else + if (*is_infershape_ != 
RET_INFER_INVALID) { + auto search_sub_graph = + SearchSubGraph(context_, src_model_, src_tensors_, &op_parameters_, &graph_output_node_indexes_); + search_sub_graph.SubGraphSplit(); + } +#endif #else MS_LOG(ERROR) << unsupport_auto_parallel_log; return RET_NOT_SUPPORT; diff --git a/mindspore/lite/src/sub_graph_split.cc b/mindspore/lite/src/sub_graph_split.cc index 29cc28b2787..50bc506f5f3 100644 --- a/mindspore/lite/src/sub_graph_split.cc +++ b/mindspore/lite/src/sub_graph_split.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include "src/tensor.h" #include "schema/ops_generated.h" #include "schema/model_generated.h" @@ -216,9 +217,11 @@ const schema::Primitive *SearchSubGraph::CreatePartialPrimitive(int64_t subgraph } void SearchSubGraph::ConvertSubGraphToModel(std::vector *sub_graphs) { +#ifndef OPERATOR_PARALLELISM if (sub_graphs->size() != kDefaultSubGraphSize) { return; } +#endif Model::SubGraph *main_graphs = model_->sub_graphs_.front(); for (Subgraph &subgraph : *sub_graphs) { @@ -1069,4 +1072,96 @@ void SearchSubGraph::SubGraphSplit() { } return; } + +#ifdef OPERATOR_PARALLELISM +void SearchSubGraph::InsertNodeBegin(uint32_t index, Subgraph *subgraph, std::vector *outputs) { + size_t last_index = index; + + while (1) { + Model::Node *node = node_list_.at(index); + if (node == nullptr) { + subgraph->heads_.push_back(last_index); + return; + } + + std::vector input = node->input_indices_; + RemoveConstNode(&input); + + /* all node_input is graph_input */ + for (size_t i = 0; i < input.size(); i++) { + if (tensors_[input[i]].type_ != INPUT) { + break; + } + subgraph->heads_.push_back(last_index); + return; + } + + /* split in graph */ + if (IsNodeSubGraphHead(index, subgraph->nodes_)) { + if (subgraph->nodes_.empty()) { + subgraph->heads_.push_back(index); + subgraph->nodes_.insert(subgraph->nodes_.begin(), index); + node_list_.at(index) = nullptr; + for (uint32_t in : input) { + auto next_nodes = tensors_[in].out_nodes_; + 
std::copy(next_nodes.begin(), next_nodes.end(), std::back_inserter(*outputs)); + } + return; + } + subgraph->heads_.push_back(last_index); + outputs->push_back(index); + return; + } + + for (uint32_t in : input) { + auto next_nodes = tensors_[in].out_nodes_; + std::copy(next_nodes.begin(), next_nodes.end(), std::back_inserter(*outputs)); + } + subgraph->nodes_.insert(subgraph->nodes_.begin(), index); + node_list_.at(index) = nullptr; + if (outputs->size() == 1) { + last_index = index; + index = outputs->at(0); + outputs->clear(); + } else { + subgraph->heads_.push_back(index); + return; + } + } + + return; +} + +void SearchSubGraph::SubGraphSplitByOperator() { + if (!ValidInParallel()) { + return; + } + sub_graphs_.clear(); + node_list_ = model_->all_nodes_; + std::queue outputs{}; + for (auto out : *output_nodes_) { + outputs.push(out); + } + std::vector outputs_vec{}; + while (!outputs.empty()) { + auto out = outputs.front(); + outputs.pop(); + + Subgraph subgraph; + subgraph.ends_.push_back(out); + subgraph.device_ = DT_CPU; + subgraph.thread_ = context_->thread_num_ > 2 ? 
(context_->thread_num_ / 2) : 1; + + InsertNodeBegin(static_cast(out), &subgraph, &outputs_vec); + for (auto new_out : outputs_vec) { + outputs.push(new_out); + } + outputs_vec.clear(); + if (!subgraph.nodes_.empty()) { + sub_graphs_.push_back(std::move(subgraph)); + } + } + ConvertSubGraphToModel(&sub_graphs_); +} +#endif } // namespace mindspore::lite diff --git a/mindspore/lite/src/sub_graph_split.h b/mindspore/lite/src/sub_graph_split.h index bd805330c2c..50f27ebdfce 100644 --- a/mindspore/lite/src/sub_graph_split.h +++ b/mindspore/lite/src/sub_graph_split.h @@ -88,6 +88,10 @@ class SearchSubGraph { public: void SubGraphSplit(); +#ifdef OPERATOR_PARALLELISM + void SubGraphSplitByOperator(); + void InsertNodeBegin(uint32_t index, Subgraph *subgraph, std::vector *outputs); +#endif private: /* split by output */ void SubGraphSplitByOutput(); diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 4b0627d3e06..bbcb4cf69c1 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -35,6 +35,10 @@ file(GLOB_RECURSE TEST_UT_SRC ${TEST_DIR}/ut/src/api/tensor_c_test.cc ) +if(MSLITE_ENABLE_SERVER_INFERENCE) + list(REMOVE_ITEM TEST_UT_SRC ${TEST_DIR}/st/mindrt_parallel_runtime_test.cc) +endif() + if(MSLITE_ENABLE_RUNTIME_CONVERT) list(APPEND TEST_UT_SRC ${TEST_DIR}/ut/src/runtime/runtime_convert_tests.cc) endif() @@ -115,6 +119,12 @@ if(MSLITE_ENABLE_CONVERTER) ${TEST_DIR}/ut/src/dynamic_library_loader_test.cc ${TEST_DIR}/ut/tools/optimizer/fusion/*.cc ) + if(MSLITE_ENABLE_SERVER_INFERENCE) + list(REMOVE_ITEM TEST_CONVERTER_UT_SRC ${TEST_DIR}/st/mindrt_parallel_test.cc) + list(REMOVE_ITEM TEST_UT_SRC ${TEST_DIR}/st/benchmark_test.cc) + list(REMOVE_ITEM TEST_CONVERTER_UT_SRC ${TEST_DIR}/st/graph_test.cc) + list(REMOVE_ITEM TEST_CONVERTER_UT_SRC ${TEST_DIR}/st/sub_graph_test.cc) + endif() list(APPEND TEST_UT_SRC ${TEST_CONVERTER_UT_SRC}) set(TEST_LITE_SRC