[MSLITE] train session mindrt

This commit is contained in:
ling 2021-09-02 17:11:30 +08:00
parent 6d0bdd83da
commit 2838c10010
2 changed files with 38 additions and 49 deletions

View File

@ -467,17 +467,6 @@ void LiteSession::FreePackOpWeight(const std::vector<kernel::LiteKernel *> &kern
}
}
bool LiteSession::IfUseMindrtExecutor() {
bool use_mindrt_run = true;
#ifdef ENABLE_MINDRT
use_mindrt_run = (is_train_session_) ? false : true;
#else
use_mindrt_run = false;
#endif
return use_mindrt_run;
}
int LiteSession::CompileGraph(Model *model) {
bool expected = false;
if (!is_running_.compare_exchange_strong(expected, true)) {
@ -526,31 +515,29 @@ int LiteSession::CompileGraph(Model *model) {
}
InitGraphInOutTensorsMap(model);
bool use_mindrt_run = IfUseMindrtExecutor();
ret = PrepareKernels(model, use_mindrt_run);
ret = PrepareKernels(model);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Prepare kernels failed: " << ret;
is_running_.store(false);
return ret;
}
#ifdef ENABLE_MINDRT
if (use_mindrt_run) {
ret = IsolateOutputTensor();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Isolate output tensor failed.";
is_running_.store(false);
return ret;
}
executor_ = new (std::nothrow) MindrtExecutor(&graph_output_map_);
} else {
#endif
executor_ = new (std::nothrow) Executor();
#ifdef ENABLE_MINDRT
if (is_train_session_) {
is_running_.store(false);
return RET_OK;
}
#endif
#ifdef ENABLE_MINDRT
ret = IsolateOutputTensor();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Isolate output tensor failed.";
is_running_.store(false);
return ret;
}
executor_ = new (std::nothrow) MindrtExecutor(&graph_output_map_);
#else
executor_ = new (std::nothrow) Executor();
#endif
if (executor_ == nullptr) {
MS_LOG(ERROR) << "New Executor failed";
is_running_.store(false);
@ -563,10 +550,10 @@ int LiteSession::CompileGraph(Model *model) {
is_running_.store(false);
return ret;
}
if (!is_train_session_) {
// For reducing runtime RAM, free packop weight because packop will pack weight and will not access to origin weight
FreePackOpWeight(kernels_);
}
// For reducing runtime RAM, free packop weight because packop will pack weight and will not access to origin weight
FreePackOpWeight(kernels_);
is_running_.store(false);
return RET_OK;
}
@ -587,30 +574,34 @@ bool LiteSession::IsIsolatedSubGraph(kernel::LiteKernel *kernel) {
return true;
}
int LiteSession::PrepareKernels(Model *model, bool use_mindrt_run) {
int LiteSession::PrepareKernels(Model *model) {
std::vector<kernel::LiteKernel *> all_kernels;
// find in_kernels and out_kernels for subgraphs
for (auto kernel : this->kernels_) {
kernel->FindInoutKernels(this->kernels_);
#ifndef DELEGATE_CLIP
if (kernel->desc().arch == kernel::kDelegate) {
all_kernels.push_back(kernel);
} else {
#endif
auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
MS_ASSERT(sub_graph != nullptr);
auto kernel_in_subgraph = sub_graph->nodes();
all_kernels.insert(all_kernels.end(), kernel_in_subgraph.begin(), kernel_in_subgraph.end());
#ifndef DELEGATE_CLIP
continue;
}
#endif
auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
MS_ASSERT(sub_graph != nullptr);
auto kernel_in_subgraph = sub_graph->nodes();
all_kernels.insert(all_kernels.end(), kernel_in_subgraph.begin(), kernel_in_subgraph.end());
}
if (!use_mindrt_run) {
// find in_kernels and out_kernels for kernels
for (auto *kernel : all_kernels) {
kernel->FindInoutKernels(all_kernels);
// find in_sub and out_sub for subgraph
for (auto kernel : this->kernels_) {
kernel->FindInoutKernels(this->kernels_);
}
// find in_kernels and out_kernels for kernels
for (auto *kernel : all_kernels) {
#ifndef DELEGATE_CLIP
if (kernel->desc().arch == kernel::kDelegate) {
continue;
}
#endif
kernel->FindInoutKernels(all_kernels);
}
// init init_ref_count for subgraphs and kernels

View File

@ -116,7 +116,7 @@ class LiteSession : public session::LiteSession {
int ResizeInputs(const std::vector<mindspore::tensor::MSTensor *> &inputs, const std::vector<std::vector<int>> &dims);
int PrepareKernels(Model *model, bool use_mindrt_run);
int PrepareKernels(Model *model);
static int ReSizeKernels(const std::vector<kernel::LiteKernel *> &kernels);
@ -127,8 +127,6 @@ class LiteSession : public session::LiteSession {
int InitGPURuntime();
bool IfUseMindrtExecutor();
bool IsIsolatedSubGraph(kernel::LiteKernel *kernel);
protected: