!19745 [MS][LITE] support model build with model buffer
Merge pull request !19745 from 张学同/callback
Commit 2bfe69adca
@@ -54,6 +54,10 @@ class MS_API Model {
   static bool CheckModelSupport(enum DeviceType device_type, ModelType model_type);
 
+  Status Build(const void *model_data, size_t data_size, ModelType model_type,
+               const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {},
+               const std::string &dec_mode = "AES-GCM");
+
  private:
   // api without std::string
   MSTensor GetInputByTensorName(const std::vector<char> &tensor_name);
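For orientation, here is a minimal usage sketch of the new buffer-based build path declared above. Only the Build(model_data, data_size, model_type, ...) signature comes from this change; the file-reading helper, the CPU context setup, and the kMindIR model type are illustrative assumptions, not part of the pull request.

// Illustrative sketch: build a model directly from an in-memory buffer.
// ReadBinaryFile is a hypothetical helper; error handling is trimmed.
#include <fstream>
#include <memory>
#include <string>
#include <vector>
#include "include/api/model.h"
#include "include/api/context.h"

std::vector<char> ReadBinaryFile(const std::string &path) {
  std::ifstream ifs(path, std::ios::binary | std::ios::ate);
  std::vector<char> buf(static_cast<size_t>(ifs.tellg()));
  ifs.seekg(0);
  ifs.read(buf.data(), buf.size());
  return buf;
}

int main() {
  auto buf = ReadBinaryFile("model.ms");  // assumption: a converted lite model file
  auto context = std::make_shared<mindspore::Context>();
  context->MutableDeviceInfo().push_back(std::make_shared<mindspore::CPUDeviceInfo>());

  mindspore::Model model;
  // The new overload from this change: build from a raw buffer instead of a GraphCell.
  auto status = model.Build(buf.data(), buf.size(), mindspore::ModelType::kMindIR, context);
  if (status != mindspore::kSuccess) {
    return -1;
  }
  return 0;
}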
@@ -27,12 +27,6 @@
 #include "include/api/dual_abi_helper.h"
 
 namespace mindspore {
-using Key = struct Key {
-  const size_t max_key_len = 32;
-  size_t len;
-  unsigned char key[32];
-  Key() : len(0) {}
-};
-
 class MS_API Serialization {
  public:
@@ -144,6 +144,12 @@ MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vecto
 
 std::string MSTensor::Name() const { return CharToString(CharName()); }
 
+using Key = struct Key {
+  const size_t max_key_len = 32;
+  size_t len;
+  unsigned char key[32];
+  Key() : len(0) {}
+};
 /// \brief CallBackParam defined input arguments for callBack function.
 struct MSCallBackParam {
   std::string node_name_; /**< node name argument */
@@ -65,6 +65,11 @@ Status Model::Build(GraphCell graph_cell, const std::shared_ptr<Context> &model_
   return impl_->Build();
 }
 
+Status Model::Build(const void *model_data, size_t data_size, ModelType model_type,
+                    const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode) {
+  MS_LOG(ERROR) << "Unsupported Feature.";
+  return kMCFailed;
+}
 Status Model::Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) {
   if (impl_ == nullptr) {
     MS_LOG(ERROR) << "Failed because this model has not been built.";
@@ -448,7 +448,7 @@ void GraphScheduler::Initialize() {
   size_t actor_thread_num = 0;
   size_t OMP_thread_num = 0;
   ComputeThreadNums(&actor_thread_num, &OMP_thread_num);
-  thread_pool_ = ActorThreadPool::CreateThreadPool(actor_thread_num, kThreadSpin);
+  thread_pool_ = ActorThreadPool::CreateThreadPool(actor_thread_num);
   MS_EXCEPTION_IF_NULL(thread_pool_);
   std::string OMP_env = std::to_string(OMP_thread_num);
   common::SetEnv("OMP_NUM_THREADS", OMP_env.c_str(), 0);
@@ -18,14 +18,10 @@
 #include "thread/core_affinity.h"
 
 namespace mindspore {
-void ActorWorker::CreateThread(ActorThreadPool *pool, ThreadPolicy policy) {
+void ActorWorker::CreateThread(ActorThreadPool *pool) {
   THREAD_RETURN_IF_NULL(pool);
   pool_ = pool;
-  if (policy == kThreadSpin) {
-    thread_ = std::thread(&ActorWorker::RunWithSpin, this);
-  } else if (policy == kThreadWait) {
-    thread_ = std::thread(&ActorWorker::RunWithWait, this);
-  }
+  thread_ = std::thread(&ActorWorker::RunWithSpin, this);
 }
 
 void ActorWorker::RunWithSpin() {
@@ -47,21 +43,6 @@ void ActorWorker::RunWithSpin() {
   }
 }
 
-void ActorWorker::RunWithWait() {
-#ifndef __APPLE__
-  static std::atomic_int index = {0};
-  pthread_setname_np(pthread_self(), ("ActorThread_" + std::to_string(index++)).c_str());
-#endif
-  while (alive_) {
-    // only run PoolQueue ActorTask
-    bool success = RunQueueActorTask();
-    if (!success) {
-      // wait until enqueue ActorTask
-      pool_->WaitUntilNotify();
-    }
-  }
-}
-
 bool ActorWorker::RunQueueActorTask() {
   THREAD_ERROR_IF_NULL(pool_);
   auto actor = pool_->PopActorFromQueue();
@@ -96,8 +77,6 @@ ActorThreadPool::~ActorThreadPool() {
       std::this_thread::yield();
     }
   } while (!terminate);
-  exit_ = true;
-  actor_cond_.notify_all();
   for (auto &worker : workers_) {
     delete worker;
     worker = nullptr;
@@ -105,11 +84,6 @@ ActorThreadPool::~ActorThreadPool() {
   workers_.clear();
 }
 
-void ActorThreadPool::WaitUntilNotify() {
-  std::unique_lock<std::mutex> _l(actor_mutex_);
-  actor_cond_.wait(_l, [this] { return !actor_queue_.empty() || exit_; });
-}
-
 ActorReference ActorThreadPool::PopActorFromQueue() {
   std::lock_guard<std::mutex> _l(actor_mutex_);
   if (actor_queue_.empty()) {
@@ -125,7 +99,6 @@ void ActorThreadPool::PushActorToQueue(const ActorReference &actor) {
     std::lock_guard<std::mutex> _l(actor_mutex_);
     actor_queue_.push(actor);
   }
-  actor_cond_.notify_one();
   THREAD_INFO("actor[%s] enqueue success", actor->GetAID().Name().c_str());
   // active one idle actor thread if exist
   for (size_t i = 0; i < actor_thread_num_; ++i) {
@@ -136,7 +109,7 @@ void ActorThreadPool::PushActorToQueue(const ActorReference &actor) {
   }
 }
 
-int ActorThreadPool::CreateThreads(size_t actor_thread_num, size_t all_thread_num, ThreadPolicy policy) {
+int ActorThreadPool::CreateThreads(size_t actor_thread_num, size_t all_thread_num) {
   size_t core_num = std::thread::hardware_concurrency();
   THREAD_INFO("ThreadInfo, Actor: [%zu], All: [%zu], CoreNum: [%zu]", actor_thread_num, all_thread_num, core_num);
   actor_thread_num_ = actor_thread_num < core_num ? actor_thread_num : core_num;
@@ -148,7 +121,7 @@ int ActorThreadPool::CreateThreads(size_t actor_thread_num, size_t all_thread_nu
     std::lock_guard<std::mutex> _l(pool_mutex_);
     auto worker = new (std::nothrow) ActorWorker();
     THREAD_ERROR_IF_NULL(worker);
-    worker->CreateThread(this, policy);
+    worker->CreateThread(this);
     workers_.push_back(worker);
     THREAD_INFO("create actor thread[%zu]", i);
   }
@@ -159,13 +132,12 @@ int ActorThreadPool::CreateThreads(size_t actor_thread_num, size_t all_thread_nu
   return THREAD_OK;
 }
 
-ActorThreadPool *ActorThreadPool::CreateThreadPool(size_t actor_thread_num, size_t all_thread_num,
-                                                   ThreadPolicy policy) {
+ActorThreadPool *ActorThreadPool::CreateThreadPool(size_t actor_thread_num, size_t all_thread_num) {
   ActorThreadPool *pool = new (std::nothrow) ActorThreadPool();
   if (pool == nullptr) {
     return nullptr;
   }
-  int ret = pool->CreateThreads(actor_thread_num, all_thread_num, policy);
+  int ret = pool->CreateThreads(actor_thread_num, all_thread_num);
   if (ret != THREAD_OK) {
     delete pool;
     return nullptr;
@@ -180,12 +152,12 @@ ActorThreadPool *ActorThreadPool::CreateThreadPool(size_t actor_thread_num, size
   return pool;
 }
 
-ActorThreadPool *ActorThreadPool::CreateThreadPool(size_t thread_num, ThreadPolicy policy) {
+ActorThreadPool *ActorThreadPool::CreateThreadPool(size_t thread_num) {
   ActorThreadPool *pool = new (std::nothrow) ActorThreadPool();
   if (pool == nullptr) {
     return nullptr;
   }
-  int ret = pool->CreateThreads(thread_num, thread_num, policy);
+  int ret = pool->CreateThreads(thread_num, thread_num);
   if (ret != THREAD_OK) {
     delete pool;
     return nullptr;
@@ -26,20 +26,15 @@
 #include "thread/hqueue.h"
 
 namespace mindspore {
-enum ThreadPolicy {
-  kThreadSpin = 0,  // thread run in spin
-  kThreadWait = 1   // synchronous and wait
-};
-
 class ActorThreadPool;
 
 class ActorWorker : public Worker {
  public:
-  void CreateThread(ActorThreadPool *pool, ThreadPolicy policy);
+  void CreateThread(ActorThreadPool *pool);
   bool Active();
 
  private:
-  void RunWithWait();
   void RunWithSpin();
   bool RunQueueActorTask();
@@ -49,24 +44,22 @@ class ActorWorker : public Worker {
 class ActorThreadPool : public ThreadPool {
  public:
   // create ThreadPool that contains actor thread and kernel thread
-  static ActorThreadPool *CreateThreadPool(size_t actor_thread_num, size_t all_thread_num, ThreadPolicy policy);
+  static ActorThreadPool *CreateThreadPool(size_t actor_thread_num, size_t all_thread_num);
   // create ThreadPool that contains only actor thread
-  static ActorThreadPool *CreateThreadPool(size_t thread_num, ThreadPolicy policy);
+  static ActorThreadPool *CreateThreadPool(size_t thread_num);
   ~ActorThreadPool() override;
 
   void PushActorToQueue(const ActorReference &actor);
   ActorReference PopActorFromQueue();
-  void WaitUntilNotify();
 
  private:
   ActorThreadPool() {}
-  int CreateThreads(size_t actor_thread_num, size_t all_thread_num, ThreadPolicy policy);
+  int CreateThreads(size_t actor_thread_num, size_t all_thread_num);
 
   size_t actor_thread_num_{0};
 
-  bool exit_{false};
   std::mutex actor_mutex_;
-  std::condition_variable actor_cond_;
+  // std::condition_variable actor_cond_;
   std::queue<ActorReference> actor_queue_;
 };
 }  // namespace mindspore
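Taken together, the thread-pool hunks above drop the ThreadPolicy knob: every actor worker now spins on the task queue (RunWithSpin), so RunWithWait, WaitUntilNotify, and the condition-variable notifications become dead code and are removed. For clarity, a generic sketch of the resulting spin-worker pattern follows; the class and method names are illustrative, not this codebase's exact logic.

// Minimal spin-polling worker sketch (illustrative, not MindSpore's code).
#include <atomic>
#include <thread>

class SpinWorker {
 public:
  void Start() { thread_ = std::thread(&SpinWorker::RunWithSpin, this); }
  void Stop() {
    alive_.store(false);
    if (thread_.joinable()) thread_.join();
  }

 private:
  void RunWithSpin() {
    while (alive_.load(std::memory_order_relaxed)) {
      if (!TryRunOneTask()) {
        // No work: yield instead of blocking on a condition variable,
        // trading some CPU for lower wake-up latency.
        std::this_thread::yield();
      }
    }
  }
  bool TryRunOneTask() { return false; }  // placeholder: pop and run a queued task

  std::atomic<bool> alive_{true};
  std::thread thread_;
};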
@@ -22,6 +22,19 @@
 #include "src/common/log_adapter.h"
 
 namespace mindspore {
+Status Model::Build(const void *model_data, size_t data_size, ModelType model_type,
+                    const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode) {
+  impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
+  if (impl_ == nullptr) {
+    MS_LOG(ERROR) << "Model implement is null.";
+    return kLiteNullptr;
+  }
+  Status ret = impl_->Build(model_data, data_size, model_type, model_context);
+  if (ret != kSuccess) {
+    return ret;
+  }
+  return kSuccess;
+}
 Status Model::Build(GraphCell graph, const std::shared_ptr<Context> &model_context) {
   if (impl_ != nullptr) {
     MS_LOG(DEBUG) << "Model has been already built.";
@@ -45,8 +45,8 @@ lite::CpuBindMode ModelImpl::GetCpuBindMode() {
   }
 }
 
-Status ModelImpl::ConverterContext(lite::Context *model_context) {
-  auto device_list = context_->MutableDeviceInfo();
+Status ModelImpl::ConverterContext(const std::shared_ptr<Context> &context, lite::Context *model_context) {
+  auto device_list = context->MutableDeviceInfo();
   if (device_list.size() == 0) {
     MS_LOG(ERROR) << "Invalid device list.";
     return kLiteInputParamInvalid;
@@ -56,9 +56,9 @@ Status ModelImpl::ConverterContext(lite::Context *model_context) {
     return kLiteInputParamInvalid;
   }
 
-  model_context->thread_num_ = context_->GetThreadNum();
-  model_context->enable_parallel_ = context_->GetEnableParallel();
-  model_context->affinity_core_list_ = context_->GetThreadAffinityCoreList();
+  model_context->thread_num_ = context->GetThreadNum();
+  model_context->enable_parallel_ = context->GetEnableParallel();
+  model_context->affinity_core_list_ = context->GetThreadAffinityCoreList();
   model_context->device_list_.clear();
   if (device_list[0]->GetDeviceType() != kCPU) {
     MS_LOG(ERROR) << "CPU context must be enabled and in the first place of device list.";
@@ -101,6 +101,26 @@ Status ModelImpl::ConverterContext(lite::Context *model_context) {
   return kSuccess;
 }
 
+Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType model_type,
+                        const std::shared_ptr<Context> &ms_context) {
+  lite::Context lite_context;
+  auto status = ConverterContext(ms_context, &lite_context);
+  if (status != kSuccess) {
+    return status;
+  }
+
+  auto session = std::shared_ptr<session::LiteSession>(
+    session::LiteSession::CreateSession(static_cast<const char *>(model_data), data_size, &lite_context));
+  if (session == nullptr) {
+    MS_LOG(ERROR) << "Allocate session failed.";
+    return kLiteNullptr;
+  }
+
+  session_.swap(session);
+  MS_LOG(DEBUG) << "Build model success.";
+  return kSuccess;
+}
+
 Status ModelImpl::Build() {
   MS_LOG(DEBUG) << "Start build model.";
   auto model = graph_->graph_data_->lite_model();
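The new ModelImpl::Build above is a thin wrapper: it converts the public Context into a lite::Context and hands the raw buffer to the long-standing session-level API. A hedged sketch of that underlying call (MindSpore Lite 1.x C++ API; the thread_num_ value below is an arbitrary illustration):

// Sketch of the underlying lite-API call that the new ModelImpl::Build wraps
// (assumes MindSpore Lite 1.x headers; illustrative only, error handling trimmed).
#include "include/lite_session.h"
#include "include/context.h"

mindspore::session::LiteSession *BuildFromBuffer(const char *model_buf, size_t size) {
  mindspore::lite::Context context;
  context.thread_num_ = 2;  // assumption: tune per device
  // CreateSession parses the flatbuffer, compiles the graph, and returns
  // a ready-to-run session, or nullptr on failure.
  return mindspore::session::LiteSession::CreateSession(model_buf, size, &context);
}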
@@ -117,7 +137,7 @@ Status ModelImpl::Build() {
     return kLiteNullptr;
   }
   lite::Context model_context;
-  auto status = ConverterContext(&model_context);
+  auto status = ConverterContext(context_, &model_context);
   if (status != kSuccess) {
     return status;
   }
@@ -36,6 +36,8 @@ class ModelImpl {
   ~ModelImpl() = default;
 
   Status Build();
+  Status Build(const void *model_data, size_t data_size, ModelType model_type,
+               const std::shared_ptr<Context> &model_context);
   Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
 
   Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs, const MSKernelCallBack &before,
@@ -58,7 +60,7 @@ class ModelImpl {
   void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
   void SetContext(const std::shared_ptr<Context> &context) { context_ = context; }
   lite::CpuBindMode GetCpuBindMode();
-  Status ConverterContext(lite::Context *model_context);
+  Status ConverterContext(const std::shared_ptr<Context> &context, lite::Context *model_context);
   Status RunGraph(const MSKernelCallBack &before, const MSKernelCallBack &after);
 };
 }  // namespace mindspore
@@ -67,7 +67,7 @@ int InnerContext::Init() {
   }
   if (this->thread_pool_ == nullptr && this->IsCpuEnabled()) {
     int actor_parallel_thread = this->enable_parallel_ ? 2 : 1;
-    thread_pool_ = ActorThreadPool::CreateThreadPool(actor_parallel_thread, this->thread_num_, kThreadSpin);
+    thread_pool_ = ActorThreadPool::CreateThreadPool(actor_parallel_thread, this->thread_num_);
     if (thread_pool_ == nullptr) {
       MS_LOG(ERROR) << "Create ThreadPool failed";
       return RET_NULL_PTR;
@@ -455,15 +455,6 @@ bool LiteSession::IfUseMindrtExecutor() {
   use_mindrt_run = false;
 #endif
 
-  for (auto kernel : kernels_) {
-    if (kernel->desc().delegate != nullptr) {
-      continue;
-    }
-    auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
-    if (sub_graph->nodes()[0]->type() == schema::PrimitiveType_Merge) {
-      use_mindrt_run = false; /* control-flow model */
-    }
-  }
   return use_mindrt_run;
 }
 
@@ -1,99 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <utility>
-#include "src/runtime/parallel_executor.h"
-#include "src/lite_kernel_util.h"
-
-namespace mindspore::lite {
-ParallelExecutor::~ParallelExecutor() { delete thread_pool_; }
-int ParallelExecutor::Prepare(const std::vector<mindspore::kernel::LiteKernel *> &kernels,
-                              const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
-                              const lite::InnerContext *ctx) {
-  thread_pool_ = ActorThreadPool::CreateThreadPool(1, max_thread_num_, kThreadSpin);
-  if (thread_pool_ == nullptr) {
-    MS_LOG(ERROR) << "Memory error: fail to new ThreadPool";
-    return RET_ERROR;
-  }
-  return RET_OK;
-}
-
-static int RunKernel(void *data, int index, float lhs_scale, float rhs_scale) {
-  auto *executor = reinterpret_cast<ParallelExecutor *>(data);
-  auto kernel = executor->GetReadyKernel(index);
-  auto ret = kernel->Execute();
-  executor->SetResult(index, ret);
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "run kernel failed, name: " << kernel->name();
-    return 0;
-  }
-  return 0;
-}
-
-int ParallelExecutor::Run(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
-                          const std::vector<kernel::LiteKernel *> &kernels, mindspore::Allocator *allocator,
-                          const KernelCallBack &before, const KernelCallBack &after) {
-  MS_ASSERT(allocator != nullptr);
-  for (auto &inTensor : in_tensors) {
-    if (inTensor == nullptr) {
-      MS_LOG(ERROR) << "Graph input tensor is nullptr";
-      return RET_ERROR;
-    }
-    if (inTensor->format() != mindspore::NHWC) {
-      MS_LOG(ERROR) << "Model input tensor should be NHWC";
-      return RET_ERROR;
-    }
-  }
-  kernel::LiteKernelUtil::InitTensorInitRefCount(kernels);
-
-  for (auto kernel : kernels) {
-    if (kernel->in_kernels().empty()) {
-      readyKernels.emplace_back(kernel);
-      continue;
-    }
-    refCount[kernel] = kernel->in_kernels().size();
-  }
-  std::vector<kernel::LiteKernel *> newReadyKernels;
-  while (!readyKernels.empty()) {
-    results.resize(readyKernels.size(), RET_OK);
-    if (thread_pool_->ParallelLaunch(RunKernel, this, readyKernels.size()) != 0) {
-      MS_LOG(ERROR) << "ParallelLaunch failed ";
-      return RET_ERROR;
-    }
-
-    if (std::find_if(results.begin(), results.end(), [](const int &ret) { return (ret != 0); }) != results.end()) {
-      return RET_ERROR;
-    }
-    newReadyKernels.clear();
-    for (auto completed : readyKernels) {
-      for (auto out : completed->out_kernels()) {
-        auto iter = refCount.find(out);
-        if (iter == refCount.end()) {
-          continue;
-        }
-        (iter->second)--;
-        if (iter->second <= 0) {
-          newReadyKernels.emplace_back(iter->first);
-          refCount.erase(iter);
-        }
-      }
-    }
-    readyKernels.clear();
-    readyKernels = std::move(newReadyKernels);
-  }
-  return RET_OK;
-}
-}  // namespace mindspore::lite
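Before its removal, the scheduling loop in ParallelExecutor::Run was a Kahn-style wavefront: kernels with no in-kernels seed the ready set, and each completed kernel decrements its consumers' pending counts until the graph drains. A self-contained sketch of that idea (hypothetical Node type; sequential where the original launched each wavefront in parallel):

// Kahn-style wavefront scheduling sketch (illustrative; mirrors the logic of
// the deleted ParallelExecutor::Run with a hypothetical Node type).
#include <unordered_map>
#include <vector>

struct Node {
  std::vector<Node *> outputs;  // consumers of this node's results
  size_t num_inputs = 0;        // producers this node waits on
  void Run() {}                 // stand-in for kernel execution
};

void RunWavefronts(const std::vector<Node *> &nodes) {
  std::unordered_map<Node *, size_t> pending;
  std::vector<Node *> ready;
  for (Node *n : nodes) {
    if (n->num_inputs == 0) {
      ready.push_back(n);  // sources form the first wavefront
    } else {
      pending[n] = n->num_inputs;
    }
  }
  while (!ready.empty()) {
    std::vector<Node *> next;
    for (Node *n : ready) n->Run();  // the real executor ran these in parallel
    for (Node *n : ready) {
      for (Node *out : n->outputs) {
        auto it = pending.find(out);
        if (it != pending.end() && --(it->second) == 0) {
          next.push_back(out);  // all producers done: schedulable next round
          pending.erase(it);
        }
      }
    }
    ready = std::move(next);
  }
}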
@@ -1,53 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_LITE_SRC_RUNTIME_PARALLEL_EXECUTOR_H_
-#define MINDSPORE_LITE_SRC_RUNTIME_PARALLEL_EXECUTOR_H_
-
-#include <vector>
-#include <thread>
-#include <unordered_map>
-#include "src/runtime/inner_allocator.h"
-#include "src/lite_kernel.h"
-#include "include/lite_session.h"
-#include "src/executor.h"
-#include "thread/actor_threadpool.h"
-
-namespace mindspore::lite {
-class ParallelExecutor : public Executor {
- public:
-  ParallelExecutor() = default;
-  ~ParallelExecutor() override;
-
-  int Prepare(const std::vector<kernel::LiteKernel *> &kernels, const std::vector<Tensor *> &inputs,
-              const std::vector<Tensor *> &outputs, const lite::InnerContext *ctx) override;
-
-  int Run(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
-          const std::vector<kernel::LiteKernel *> &kernels, mindspore::Allocator *allocator = nullptr,
-          const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) override;
-  inline kernel::LiteKernel *GetReadyKernel(const int index) const { return readyKernels.at(index); }
-  inline void SetResult(const int index, const int result) { results.at(index) = result; }
-
- private:
-  std::unordered_map<kernel::LiteKernel *, size_t> refCount;
-  std::vector<kernel::LiteKernel *> readyKernels;
-  std::vector<int> results;
-  ActorThreadPool *thread_pool_ = nullptr;
-  int max_thread_num_ = std::thread::hardware_concurrency();
-};
-
-}  // namespace mindspore::lite
-#endif  // MINDSPORE_LITE_SRC_RUNTIME_PARALLEL_EXECUTOR_H_
@@ -84,7 +84,6 @@ set(TEST_LITE_SRC
     ${OPS_SRC}
     ${KERNEL_OP_SRC}
     ${LITE_DIR}/src/runtime/inner_allocator.cc
-    ${LITE_DIR}/src/runtime/parallel_executor.cc
    ${LITE_DIR}/src/runtime/infer_manager.cc
    ${LITE_DIR}/src/tensor.cc
    ${LITE_DIR}/src/ms_tensor.cc
@@ -25,7 +25,6 @@
 #include "include/errorcode.h"
 #include "src/common/log_adapter.h"
 #include "src/lite_session.h"
-#include "src/runtime/parallel_executor.h"
 #include "tools/common/storage.h"
 #include "include/version.h"
 
@@ -24,7 +24,6 @@
 #include "include/errorcode.h"
 #include "src/common/log_adapter.h"
 #include "src/lite_session.h"
-#include "src/runtime/parallel_executor.h"
 
 namespace mindspore {
 class InferTest : public mindspore::CommonTest {
@@ -233,98 +232,6 @@ TEST_F(InferTest, TestAddNode) {
   MS_LOG(INFO) << "Passed";
 }
 
-class SessionWithParallelExecutor : public lite::LiteSession {
- public:
-  int Init(lite::InnerContext *context) {
-    lite::LiteSession::Init(context);
-    delete this->executor_;
-    this->executor_ = new mindspore::lite::ParallelExecutor();
-    return 0;
-  }
-};
-
-TEST_F(InferTest, TestParallelExecutor) {
-  auto meta_graph = std::make_shared<schema::MetaGraphT>();
-  meta_graph->name = "graph";
-
-  auto node = std::make_unique<schema::CNodeT>();
-  node->inputIndex = {0, 1};
-  node->outputIndex = {2};
-  node->primitive = std::make_unique<schema::PrimitiveT>();
-  node->primitive->value.type = schema::PrimitiveType_AddFusion;
-  auto primitive = new schema::AddFusionT;
-  node->primitive->value.value = primitive;
-  node->name = "Add";
-  meta_graph->nodes.emplace_back(std::move(node));
-  meta_graph->inputIndex = {0, 1};
-  meta_graph->outputIndex = {2};
-
-  auto input0 = std::make_unique<schema::TensorT>();
-  input0->nodeType = lite::NodeType_ValueNode;
-  input0->format = schema::Format_NHWC;
-  input0->dataType = TypeId::kNumberTypeFloat32;
-  input0->dims = {1, 28, 28, 3};
-  input0->offset = -1;
-  meta_graph->allTensors.emplace_back(std::move(input0));
-
-  auto weight = std::make_unique<schema::TensorT>();
-  weight->nodeType = lite::NodeType_ValueNode;
-  weight->format = schema::Format_NHWC;
-  weight->dataType = TypeId::kNumberTypeFloat32;
-  weight->dims = {1, 28, 28, 3};
-
-  weight->offset = -1;
-  meta_graph->allTensors.emplace_back(std::move(weight));
-
-  auto output = std::make_unique<schema::TensorT>();
-  output->nodeType = lite::NodeType_Parameter;
-  output->format = schema::Format_NHWC;
-  output->dataType = TypeId::kNumberTypeFloat32;
-  output->offset = -1;
-  meta_graph->allTensors.emplace_back(std::move(output));
-
-  flatbuffers::FlatBufferBuilder builder(1024);
-  auto offset = schema::MetaGraph::Pack(builder, meta_graph.get());
-  builder.Finish(offset);
-  size_t size = builder.GetSize();
-  const char *content = reinterpret_cast<char *>(builder.GetBufferPointer());
-
-  auto model = lite::Model::Import(content, size);
-  ASSERT_NE(nullptr, model);
-  meta_graph.reset();
-  content = nullptr;
-  auto context = new lite::InnerContext;
-  auto &device_list = context->device_list_;
-  lite::DeviceContext device_ctx = {lite::DT_CPU, {false, lite::NO_BIND}};
-  device_list.push_back(device_ctx);
-  context->thread_num_ = 4;
-  ASSERT_EQ(lite::RET_OK, context->Init());
-  auto session = new SessionWithParallelExecutor();
-  session->Init(context);
-  ASSERT_NE(nullptr, session);
-  auto ret = session->CompileGraph(model);
-  ASSERT_EQ(lite::RET_OK, ret);
-  auto inputs = session->GetInputs();
-  ASSERT_EQ(inputs.size(), 2);
-  auto inTensor = inputs.front();
-  ASSERT_NE(nullptr, inTensor);
-  (void)inTensor->MutableData();
-  auto inTensor1 = inputs.back();
-  ASSERT_NE(nullptr, inTensor1);
-  (void)inTensor1->MutableData();
-  ret = session->RunGraph();
-  ASSERT_EQ(lite::RET_OK, ret);
-  auto outputs = session->GetOutputs();
-  ASSERT_EQ(outputs.size(), 1);
-  auto outTensor = outputs.begin()->second;
-  ASSERT_NE(nullptr, outTensor);
-  ASSERT_EQ(28 * 28 * 3, outTensor->ElementsNum());
-  ASSERT_EQ(TypeId::kNumberTypeFloat32, outTensor->data_type());
-  auto *outData = reinterpret_cast<float *>(outTensor->MutableData());
-  ASSERT_NE(nullptr, outData);
-  MS_LOG(INFO) << "Passed";
-}
-
 TEST_F(InferTest, TestModel) {
   auto buf = new char *[1];
   size_t model_size;
@@ -110,7 +110,7 @@ class TestActor : public ActorBase {
 
 TEST_F(LiteMindRtTest, ActorThreadPoolTest) {
   Initialize("", "", "", "", 4);
-  auto pool = ActorThreadPool::CreateThreadPool(4, kThreadSpin);
+  auto pool = ActorThreadPool::CreateThreadPool(4);
   AID t1 = Spawn(ActorReference(new TestActor("t1", pool, 1)));
   AID t2 = Spawn(ActorReference(new TestActor("t2", pool, 2)));
   AID t3 = Spawn(ActorReference(new TestActor("t3", pool, 3)));
@@ -24,7 +24,6 @@
 #include "include/errorcode.h"
 #include "src/common/log_adapter.h"
 #include "src/lite_session.h"
-#include "src/runtime/parallel_executor.h"
 #include "include/registry/kernel_interface.h"
 #include "include/registry/register_kernel.h"
 
@@ -23,7 +23,6 @@
 #include "include/errorcode.h"
 #include "src/common/log_adapter.h"
 #include "src/lite_session.h"
-#include "src/runtime/parallel_executor.h"
 #include "src/runtime/inner_allocator.h"
 #include "include/registry/kernel_interface.h"
 #include "include/registry/register_kernel.h"