!19745 [MS][LITE]support model build with model buf

Merge pull request !19745 from 张学同/callback
This commit is contained in:
i-robot 2021-07-09 02:02:44 +00:00 committed by Gitee
commit 2bfe69adca
20 changed files with 73 additions and 322 deletions

View File

@ -54,6 +54,10 @@ class MS_API Model {
static bool CheckModelSupport(enum DeviceType device_type, ModelType model_type);
Status Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {},
const std::string &dec_mode = "AES-GCM");
private:
// api without std::string
MSTensor GetInputByTensorName(const std::vector<char> &tensor_name);

View File

@ -27,12 +27,6 @@
#include "include/api/dual_abi_helper.h"
namespace mindspore {
using Key = struct Key {
const size_t max_key_len = 32;
size_t len;
unsigned char key[32];
Key() : len(0) {}
};
class MS_API Serialization {
public:

View File

@ -144,6 +144,12 @@ MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vecto
std::string MSTensor::Name() const { return CharToString(CharName()); }
using Key = struct Key {
const size_t max_key_len = 32;
size_t len;
unsigned char key[32];
Key() : len(0) {}
};
/// \brief CallBackParam defined input arguments for callBack function.
struct MSCallBackParam {
std::string node_name_; /**< node name argument */

View File

@ -65,6 +65,11 @@ Status Model::Build(GraphCell graph_cell, const std::shared_ptr<Context> &model_
return impl_->Build();
}
Status Model::Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode) {
MS_LOG(ERROR) << "Unsupported Feature.";
return kMCFailed;
}
Status Model::Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) {
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Failed because this model has not been built.";

View File

@ -448,7 +448,7 @@ void GraphScheduler::Initialize() {
size_t actor_thread_num = 0;
size_t OMP_thread_num = 0;
ComputeThreadNums(&actor_thread_num, &OMP_thread_num);
thread_pool_ = ActorThreadPool::CreateThreadPool(actor_thread_num, kThreadSpin);
thread_pool_ = ActorThreadPool::CreateThreadPool(actor_thread_num);
MS_EXCEPTION_IF_NULL(thread_pool_);
std::string OMP_env = std::to_string(OMP_thread_num);
common::SetEnv("OMP_NUM_THREADS", OMP_env.c_str(), 0);

View File

@ -18,14 +18,10 @@
#include "thread/core_affinity.h"
namespace mindspore {
void ActorWorker::CreateThread(ActorThreadPool *pool, ThreadPolicy policy) {
void ActorWorker::CreateThread(ActorThreadPool *pool) {
THREAD_RETURN_IF_NULL(pool);
pool_ = pool;
if (policy == kThreadSpin) {
thread_ = std::thread(&ActorWorker::RunWithSpin, this);
} else if (policy == kThreadWait) {
thread_ = std::thread(&ActorWorker::RunWithWait, this);
}
thread_ = std::thread(&ActorWorker::RunWithSpin, this);
}
void ActorWorker::RunWithSpin() {
@ -47,21 +43,6 @@ void ActorWorker::RunWithSpin() {
}
}
void ActorWorker::RunWithWait() {
#ifndef __APPLE__
static std::atomic_int index = {0};
pthread_setname_np(pthread_self(), ("ActorThread_" + std::to_string(index++)).c_str());
#endif
while (alive_) {
// only run PoolQueue ActorTask
bool success = RunQueueActorTask();
if (!success) {
// wait until enqueue ActorTask
pool_->WaitUntilNotify();
}
}
}
bool ActorWorker::RunQueueActorTask() {
THREAD_ERROR_IF_NULL(pool_);
auto actor = pool_->PopActorFromQueue();
@ -96,8 +77,6 @@ ActorThreadPool::~ActorThreadPool() {
std::this_thread::yield();
}
} while (!terminate);
exit_ = true;
actor_cond_.notify_all();
for (auto &worker : workers_) {
delete worker;
worker = nullptr;
@ -105,11 +84,6 @@ ActorThreadPool::~ActorThreadPool() {
workers_.clear();
}
void ActorThreadPool::WaitUntilNotify() {
std::unique_lock<std::mutex> _l(actor_mutex_);
actor_cond_.wait(_l, [this] { return !actor_queue_.empty() || exit_; });
}
ActorReference ActorThreadPool::PopActorFromQueue() {
std::lock_guard<std::mutex> _l(actor_mutex_);
if (actor_queue_.empty()) {
@ -125,7 +99,6 @@ void ActorThreadPool::PushActorToQueue(const ActorReference &actor) {
std::lock_guard<std::mutex> _l(actor_mutex_);
actor_queue_.push(actor);
}
actor_cond_.notify_one();
THREAD_INFO("actor[%s] enqueue success", actor->GetAID().Name().c_str());
// active one idle actor thread if exist
for (size_t i = 0; i < actor_thread_num_; ++i) {
@ -136,7 +109,7 @@ void ActorThreadPool::PushActorToQueue(const ActorReference &actor) {
}
}
int ActorThreadPool::CreateThreads(size_t actor_thread_num, size_t all_thread_num, ThreadPolicy policy) {
int ActorThreadPool::CreateThreads(size_t actor_thread_num, size_t all_thread_num) {
size_t core_num = std::thread::hardware_concurrency();
THREAD_INFO("ThreadInfo, Actor: [%zu], All: [%zu], CoreNum: [%zu]", actor_thread_num, all_thread_num, core_num);
actor_thread_num_ = actor_thread_num < core_num ? actor_thread_num : core_num;
@ -148,7 +121,7 @@ int ActorThreadPool::CreateThreads(size_t actor_thread_num, size_t all_thread_nu
std::lock_guard<std::mutex> _l(pool_mutex_);
auto worker = new (std::nothrow) ActorWorker();
THREAD_ERROR_IF_NULL(worker);
worker->CreateThread(this, policy);
worker->CreateThread(this);
workers_.push_back(worker);
THREAD_INFO("create actor thread[%zu]", i);
}
@ -159,13 +132,12 @@ int ActorThreadPool::CreateThreads(size_t actor_thread_num, size_t all_thread_nu
return THREAD_OK;
}
ActorThreadPool *ActorThreadPool::CreateThreadPool(size_t actor_thread_num, size_t all_thread_num,
ThreadPolicy policy) {
ActorThreadPool *ActorThreadPool::CreateThreadPool(size_t actor_thread_num, size_t all_thread_num) {
ActorThreadPool *pool = new (std::nothrow) ActorThreadPool();
if (pool == nullptr) {
return nullptr;
}
int ret = pool->CreateThreads(actor_thread_num, all_thread_num, policy);
int ret = pool->CreateThreads(actor_thread_num, all_thread_num);
if (ret != THREAD_OK) {
delete pool;
return nullptr;
@ -180,12 +152,12 @@ ActorThreadPool *ActorThreadPool::CreateThreadPool(size_t actor_thread_num, size
return pool;
}
ActorThreadPool *ActorThreadPool::CreateThreadPool(size_t thread_num, ThreadPolicy policy) {
ActorThreadPool *ActorThreadPool::CreateThreadPool(size_t thread_num) {
ActorThreadPool *pool = new (std::nothrow) ActorThreadPool();
if (pool == nullptr) {
return nullptr;
}
int ret = pool->CreateThreads(thread_num, thread_num, policy);
int ret = pool->CreateThreads(thread_num, thread_num);
if (ret != THREAD_OK) {
delete pool;
return nullptr;

View File

@ -26,20 +26,15 @@
#include "thread/hqueue.h"
namespace mindspore {
enum ThreadPolicy {
kThreadSpin = 0, // thread run in spin
kThreadWait = 1 // synchronous and wait
};
class ActorThreadPool;
class ActorWorker : public Worker {
public:
void CreateThread(ActorThreadPool *pool, ThreadPolicy policy);
void CreateThread(ActorThreadPool *pool);
bool Active();
private:
void RunWithWait();
void RunWithSpin();
bool RunQueueActorTask();
@ -49,24 +44,22 @@ class ActorWorker : public Worker {
class ActorThreadPool : public ThreadPool {
public:
// create ThreadPool that contains actor thread and kernel thread
static ActorThreadPool *CreateThreadPool(size_t actor_thread_num, size_t all_thread_num, ThreadPolicy policy);
static ActorThreadPool *CreateThreadPool(size_t actor_thread_num, size_t all_thread_num);
// create ThreadPool that contains only actor thread
static ActorThreadPool *CreateThreadPool(size_t thread_num, ThreadPolicy policy);
static ActorThreadPool *CreateThreadPool(size_t thread_num);
~ActorThreadPool() override;
void PushActorToQueue(const ActorReference &actor);
ActorReference PopActorFromQueue();
void WaitUntilNotify();
private:
ActorThreadPool() {}
int CreateThreads(size_t actor_thread_num, size_t all_thread_num, ThreadPolicy policy);
int CreateThreads(size_t actor_thread_num, size_t all_thread_num);
size_t actor_thread_num_{0};
bool exit_{false};
std::mutex actor_mutex_;
std::condition_variable actor_cond_;
// std::condition_variable actor_cond_;
std::queue<ActorReference> actor_queue_;
};
} // namespace mindspore

View File

@ -22,6 +22,19 @@
#include "src/common/log_adapter.h"
namespace mindspore {
Status Model::Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode) {
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Model implement is null.";
return kLiteNullptr;
}
Status ret = impl_->Build(model_data, data_size, model_type, model_context);
if (ret != kSuccess) {
return ret;
}
return kSuccess;
}
Status Model::Build(GraphCell graph, const std::shared_ptr<Context> &model_context) {
if (impl_ != nullptr) {
MS_LOG(DEBUG) << "Model has been already built.";

View File

@ -45,8 +45,8 @@ lite::CpuBindMode ModelImpl::GetCpuBindMode() {
}
}
Status ModelImpl::ConverterContext(lite::Context *model_context) {
auto device_list = context_->MutableDeviceInfo();
Status ModelImpl::ConverterContext(const std::shared_ptr<Context> &context, lite::Context *model_context) {
auto device_list = context->MutableDeviceInfo();
if (device_list.size() == 0) {
MS_LOG(ERROR) << "Invalid device list.";
return kLiteInputParamInvalid;
@ -56,9 +56,9 @@ Status ModelImpl::ConverterContext(lite::Context *model_context) {
return kLiteInputParamInvalid;
}
model_context->thread_num_ = context_->GetThreadNum();
model_context->enable_parallel_ = context_->GetEnableParallel();
model_context->affinity_core_list_ = context_->GetThreadAffinityCoreList();
model_context->thread_num_ = context->GetThreadNum();
model_context->enable_parallel_ = context->GetEnableParallel();
model_context->affinity_core_list_ = context->GetThreadAffinityCoreList();
model_context->device_list_.clear();
if (device_list[0]->GetDeviceType() != kCPU) {
MS_LOG(ERROR) << "CPU context must be enabled and in the first place of device list.";
@ -101,6 +101,26 @@ Status ModelImpl::ConverterContext(lite::Context *model_context) {
return kSuccess;
}
Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &ms_context) {
lite::Context lite_context;
auto status = ConverterContext(ms_context, &lite_context);
if (status != kSuccess) {
return status;
}
auto session = std::shared_ptr<session::LiteSession>(
session::LiteSession::CreateSession(static_cast<const char *>(model_data), data_size, &lite_context));
if (session == nullptr) {
MS_LOG(ERROR) << "Allocate session failed.";
return kLiteNullptr;
}
session_.swap(session);
MS_LOG(DEBUG) << "Build model success.";
return kSuccess;
}
Status ModelImpl::Build() {
MS_LOG(DEBUG) << "Start build model.";
auto model = graph_->graph_data_->lite_model();
@ -117,7 +137,7 @@ Status ModelImpl::Build() {
return kLiteNullptr;
}
lite::Context model_context;
auto status = ConverterContext(&model_context);
auto status = ConverterContext(context_, &model_context);
if (status != kSuccess) {
return status;
}

View File

@ -36,6 +36,8 @@ class ModelImpl {
~ModelImpl() = default;
Status Build();
Status Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context);
Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs, const MSKernelCallBack &before,
@ -58,7 +60,7 @@ class ModelImpl {
void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
void SetContext(const std::shared_ptr<Context> &context) { context_ = context; }
lite::CpuBindMode GetCpuBindMode();
Status ConverterContext(lite::Context *model_context);
Status ConverterContext(const std::shared_ptr<Context> &context, lite::Context *model_context);
Status RunGraph(const MSKernelCallBack &before, const MSKernelCallBack &after);
};
} // namespace mindspore

View File

@ -67,7 +67,7 @@ int InnerContext::Init() {
}
if (this->thread_pool_ == nullptr && this->IsCpuEnabled()) {
int actor_parallel_thread = this->enable_parallel_ ? 2 : 1;
thread_pool_ = ActorThreadPool::CreateThreadPool(actor_parallel_thread, this->thread_num_, kThreadSpin);
thread_pool_ = ActorThreadPool::CreateThreadPool(actor_parallel_thread, this->thread_num_);
if (thread_pool_ == nullptr) {
MS_LOG(ERROR) << "Create ThreadPool failed";
return RET_NULL_PTR;

View File

@ -455,15 +455,6 @@ bool LiteSession::IfUseMindrtExecutor() {
use_mindrt_run = false;
#endif
for (auto kernel : kernels_) {
if (kernel->desc().delegate != nullptr) {
continue;
}
auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
if (sub_graph->nodes()[0]->type() == schema::PrimitiveType_Merge) {
use_mindrt_run = false; /* control-flow model */
}
}
return use_mindrt_run;
}

View File

@ -1,99 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <utility>
#include "src/runtime/parallel_executor.h"
#include "src/lite_kernel_util.h"
namespace mindspore::lite {
ParallelExecutor::~ParallelExecutor() { delete thread_pool_; }
int ParallelExecutor::Prepare(const std::vector<mindspore::kernel::LiteKernel *> &kernels,
const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
const lite::InnerContext *ctx) {
thread_pool_ = ActorThreadPool::CreateThreadPool(1, max_thread_num_, kThreadSpin);
if (thread_pool_ == nullptr) {
MS_LOG(ERROR) << "Memory error: fail to new ThreadPool";
return RET_ERROR;
}
return RET_OK;
}
static int RunKernel(void *data, int index, float lhs_scale, float rhs_scale) {
auto *executor = reinterpret_cast<ParallelExecutor *>(data);
auto kernel = executor->GetReadyKernel(index);
auto ret = kernel->Execute();
executor->SetResult(index, ret);
if (ret != RET_OK) {
MS_LOG(ERROR) << "run kernel failed, name: " << kernel->name();
return 0;
}
return 0;
}
int ParallelExecutor::Run(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
const std::vector<kernel::LiteKernel *> &kernels, mindspore::Allocator *allocator,
const KernelCallBack &before, const KernelCallBack &after) {
MS_ASSERT(allocator != nullptr);
for (auto &inTensor : in_tensors) {
if (inTensor == nullptr) {
MS_LOG(ERROR) << "Graph input tensor is nullptr";
return RET_ERROR;
}
if (inTensor->format() != mindspore::NHWC) {
MS_LOG(ERROR) << "Model input tensor should be NHWC";
return RET_ERROR;
}
}
kernel::LiteKernelUtil::InitTensorInitRefCount(kernels);
for (auto kernel : kernels) {
if (kernel->in_kernels().empty()) {
readyKernels.emplace_back(kernel);
continue;
}
refCount[kernel] = kernel->in_kernels().size();
}
std::vector<kernel::LiteKernel *> newReadyKernels;
while (!readyKernels.empty()) {
results.resize(readyKernels.size(), RET_OK);
if (thread_pool_->ParallelLaunch(RunKernel, this, readyKernels.size()) != 0) {
MS_LOG(ERROR) << "ParallelLaunch failed ";
return RET_ERROR;
}
if (std::find_if(results.begin(), results.end(), [](const int &ret) { return (ret != 0); }) != results.end()) {
return RET_ERROR;
}
newReadyKernels.clear();
for (auto completed : readyKernels) {
for (auto out : completed->out_kernels()) {
auto iter = refCount.find(out);
if (iter == refCount.end()) {
continue;
}
(iter->second)--;
if (iter->second <= 0) {
newReadyKernels.emplace_back(iter->first);
refCount.erase(iter);
}
}
}
readyKernels.clear();
readyKernels = std::move(newReadyKernels);
}
return RET_OK;
}
} // namespace mindspore::lite

View File

@ -1,53 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_PARALLEL_EXECUTOR_H_
#define MINDSPORE_LITE_SRC_RUNTIME_PARALLEL_EXECUTOR_H_
#include <vector>
#include <thread>
#include <unordered_map>
#include "src/runtime/inner_allocator.h"
#include "src/lite_kernel.h"
#include "include/lite_session.h"
#include "src/executor.h"
#include "thread/actor_threadpool.h"
namespace mindspore::lite {
class ParallelExecutor : public Executor {
public:
ParallelExecutor() = default;
~ParallelExecutor() override;
int Prepare(const std::vector<kernel::LiteKernel *> &kernels, const std::vector<Tensor *> &inputs,
const std::vector<Tensor *> &outputs, const lite::InnerContext *ctx) override;
int Run(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
const std::vector<kernel::LiteKernel *> &kernels, mindspore::Allocator *allocator = nullptr,
const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) override;
inline kernel::LiteKernel *GetReadyKernel(const int index) const { return readyKernels.at(index); }
inline void SetResult(const int index, const int result) { results.at(index) = result; }
private:
std::unordered_map<kernel::LiteKernel *, size_t> refCount;
std::vector<kernel::LiteKernel *> readyKernels;
std::vector<int> results;
ActorThreadPool *thread_pool_ = nullptr;
int max_thread_num_ = std::thread::hardware_concurrency();
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_RUNTIME_PARALLEL_EXECUTOR_H_

View File

@ -84,7 +84,6 @@ set(TEST_LITE_SRC
${OPS_SRC}
${KERNEL_OP_SRC}
${LITE_DIR}/src/runtime/inner_allocator.cc
${LITE_DIR}/src/runtime/parallel_executor.cc
${LITE_DIR}/src/runtime/infer_manager.cc
${LITE_DIR}/src/tensor.cc
${LITE_DIR}/src/ms_tensor.cc

View File

@ -25,7 +25,6 @@
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "src/lite_session.h"
#include "src/runtime/parallel_executor.h"
#include "tools/common/storage.h"
#include "include/version.h"

View File

@ -24,7 +24,6 @@
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "src/lite_session.h"
#include "src/runtime/parallel_executor.h"
namespace mindspore {
class InferTest : public mindspore::CommonTest {
@ -233,98 +232,6 @@ TEST_F(InferTest, TestAddNode) {
MS_LOG(INFO) << "Passed";
}
class SessionWithParallelExecutor : public lite::LiteSession {
public:
int Init(lite::InnerContext *context) {
lite::LiteSession::Init(context);
delete this->executor_;
this->executor_ = new mindspore::lite::ParallelExecutor();
return 0;
}
};
TEST_F(InferTest, TestParallelExecutor) {
auto meta_graph = std::make_shared<schema::MetaGraphT>();
meta_graph->name = "graph";
auto node = std::make_unique<schema::CNodeT>();
node->inputIndex = {0, 1};
node->outputIndex = {2};
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_AddFusion;
auto primitive = new schema::AddFusionT;
node->primitive->value.value = primitive;
node->name = "Add";
meta_graph->nodes.emplace_back(std::move(node));
meta_graph->inputIndex = {0, 1};
meta_graph->outputIndex = {2};
auto input0 = std::make_unique<schema::TensorT>();
input0->nodeType = lite::NodeType_ValueNode;
input0->format = schema::Format_NHWC;
input0->dataType = TypeId::kNumberTypeFloat32;
input0->dims = {1, 28, 28, 3};
input0->offset = -1;
meta_graph->allTensors.emplace_back(std::move(input0));
auto weight = std::make_unique<schema::TensorT>();
weight->nodeType = lite::NodeType_ValueNode;
weight->format = schema::Format_NHWC;
weight->dataType = TypeId::kNumberTypeFloat32;
weight->dims = {1, 28, 28, 3};
weight->offset = -1;
meta_graph->allTensors.emplace_back(std::move(weight));
auto output = std::make_unique<schema::TensorT>();
output->nodeType = lite::NodeType_Parameter;
output->format = schema::Format_NHWC;
output->dataType = TypeId::kNumberTypeFloat32;
output->offset = -1;
meta_graph->allTensors.emplace_back(std::move(output));
flatbuffers::FlatBufferBuilder builder(1024);
auto offset = schema::MetaGraph::Pack(builder, meta_graph.get());
builder.Finish(offset);
size_t size = builder.GetSize();
const char *content = reinterpret_cast<char *>(builder.GetBufferPointer());
auto model = lite::Model::Import(content, size);
ASSERT_NE(nullptr, model);
meta_graph.reset();
content = nullptr;
auto context = new lite::InnerContext;
auto &device_list = context->device_list_;
lite::DeviceContext device_ctx = {lite::DT_CPU, {false, lite::NO_BIND}};
device_list.push_back(device_ctx);
context->thread_num_ = 4;
ASSERT_EQ(lite::RET_OK, context->Init());
auto session = new SessionWithParallelExecutor();
session->Init(context);
ASSERT_NE(nullptr, session);
auto ret = session->CompileGraph(model);
ASSERT_EQ(lite::RET_OK, ret);
auto inputs = session->GetInputs();
ASSERT_EQ(inputs.size(), 2);
auto inTensor = inputs.front();
ASSERT_NE(nullptr, inTensor);
(void)inTensor->MutableData();
auto inTensor1 = inputs.back();
ASSERT_NE(nullptr, inTensor1);
(void)inTensor1->MutableData();
ret = session->RunGraph();
ASSERT_EQ(lite::RET_OK, ret);
auto outputs = session->GetOutputs();
ASSERT_EQ(outputs.size(), 1);
auto outTensor = outputs.begin()->second;
ASSERT_NE(nullptr, outTensor);
ASSERT_EQ(28 * 28 * 3, outTensor->ElementsNum());
ASSERT_EQ(TypeId::kNumberTypeFloat32, outTensor->data_type());
auto *outData = reinterpret_cast<float *>(outTensor->MutableData());
ASSERT_NE(nullptr, outData);
MS_LOG(INFO) << "Passed";
}
TEST_F(InferTest, TestModel) {
auto buf = new char *[1];
size_t model_size;

View File

@ -110,7 +110,7 @@ class TestActor : public ActorBase {
TEST_F(LiteMindRtTest, ActorThreadPoolTest) {
Initialize("", "", "", "", 4);
auto pool = ActorThreadPool::CreateThreadPool(4, kThreadSpin);
auto pool = ActorThreadPool::CreateThreadPool(4);
AID t1 = Spawn(ActorReference(new TestActor("t1", pool, 1)));
AID t2 = Spawn(ActorReference(new TestActor("t2", pool, 2)));
AID t3 = Spawn(ActorReference(new TestActor("t3", pool, 3)));

View File

@ -24,7 +24,6 @@
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "src/lite_session.h"
#include "src/runtime/parallel_executor.h"
#include "include/registry/kernel_interface.h"
#include "include/registry/register_kernel.h"

View File

@ -23,7 +23,6 @@
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "src/lite_session.h"
#include "src/runtime/parallel_executor.h"
#include "src/runtime/inner_allocator.h"
#include "include/registry/kernel_interface.h"
#include "include/registry/register_kernel.h"