forked from mindspore-Ecosystem/mindspore
!4203 fix bug that some graph output tensor being freed
Merge pull request !4203 from hangq/master
This commit is contained in:
commit
ed2b336289
|
@ -58,7 +58,10 @@ int Executor::Run(std::vector<tensor::Tensor *> &inputs, std::vector<tensor::Ten
|
|||
}
|
||||
}
|
||||
for (auto input_kernel : kernel->GetInKernels()) {
|
||||
MS_EXCEPTION_IF_NULL(input_kernel);
|
||||
MS_ASSERT(input_kernel != nullptr);
|
||||
if (input_kernel->is_model_output()) {
|
||||
continue;
|
||||
}
|
||||
ret = input_kernel->DecOutTensorRefCount();
|
||||
if (0 != ret) {
|
||||
MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << kernel->Name() << " failed";
|
||||
|
|
|
@ -60,8 +60,7 @@ class LiteKernel {
|
|||
explicit LiteKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
|
||||
const lite::Primitive *primitive)
|
||||
: opParameter(parameter), inputs_(inputs), outputs_(outputs), primitive_(primitive),
|
||||
context_(ctx) {
|
||||
: opParameter(parameter), inputs_(inputs), outputs_(outputs), primitive_(primitive), context_(ctx) {
|
||||
opParameter->thread_num_ = ctx->thread_num_;
|
||||
this->in_kernel_.clear();
|
||||
this->out_kernel_.clear();
|
||||
|
@ -95,6 +94,10 @@ class LiteKernel {
|
|||
virtual bool is_eval() { return train_mode == false; }
|
||||
void set_name(const std::string &name) { this->name = name; }
|
||||
|
||||
void set_is_model_output(bool is_model_output) { this->is_model_output_ = is_model_output; }
|
||||
|
||||
bool is_model_output() { return this->is_model_output_; }
|
||||
|
||||
schema::PrimitiveType type() { return (schema::PrimitiveType)this->opParameter->type_; }
|
||||
|
||||
std::string type_str() { return schema::EnumNamePrimitiveType((schema::PrimitiveType)this->opParameter->type_); }
|
||||
|
@ -123,9 +126,7 @@ class LiteKernel {
|
|||
|
||||
void set_desc(const KernelKey kernel_key) { desc = kernel_key; }
|
||||
|
||||
void SetNeedReInit() {
|
||||
need_reinit = true;
|
||||
}
|
||||
void SetNeedReInit() { need_reinit = true; }
|
||||
|
||||
protected:
|
||||
bool InferShapeDone() {
|
||||
|
@ -138,8 +139,8 @@ class LiteKernel {
|
|||
KernelKey desc;
|
||||
std::string name;
|
||||
OpParameter *opParameter = nullptr;
|
||||
const lite::Primitive *primitive_;
|
||||
const lite::Context *context_;
|
||||
const lite::Primitive *primitive_ = nullptr;
|
||||
const lite::Context *context_ = nullptr;
|
||||
// tensor will free in ~lite_session()
|
||||
std::vector<lite::tensor::Tensor *> inputs_;
|
||||
std::vector<lite::tensor::Tensor *> outputs_;
|
||||
|
@ -147,6 +148,7 @@ class LiteKernel {
|
|||
std::vector<LiteKernel *> out_kernel_;
|
||||
bool train_mode = false;
|
||||
bool need_reinit = false;
|
||||
bool is_model_output_ = false;
|
||||
};
|
||||
|
||||
class SubGraphKernel : public LiteKernel {
|
||||
|
|
|
@ -79,7 +79,33 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
void LiteSession::InitGraphInOutTensor(const lite::Model *model) {
|
||||
void LiteSession::InitGraphInputTensors(const lite::Model *model) {
|
||||
auto meta_graph = model->GetMetaGraph();
|
||||
MS_ASSERT(this->inputs.empty());
|
||||
MS_ASSERT(meta_graph != nullptr);
|
||||
for (size_t i = 0; i < meta_graph->inputIndex()->size(); i++) {
|
||||
auto in_tensor_idx = size_t(meta_graph->inputIndex()->GetAs<uint32_t>(i));
|
||||
MS_ASSERT(in_tensor_idx < this->tensors.size());
|
||||
auto *in_tensor = this->tensors.at(in_tensor_idx);
|
||||
MS_ASSERT(in_tensor != nullptr);
|
||||
this->inputs.emplace_back(in_tensor);
|
||||
}
|
||||
}
|
||||
|
||||
void LiteSession::InitGraphOutputTensors(const lite::Model *model) {
|
||||
auto meta_graph = model->GetMetaGraph();
|
||||
MS_ASSERT(this->outputs.empty());
|
||||
MS_ASSERT(meta_graph != nullptr);
|
||||
for (size_t i = 0; i < meta_graph->outputIndex()->size(); i++) {
|
||||
auto out_tensor_idx = size_t(meta_graph->outputIndex()->GetAs<uint32_t>(i));
|
||||
MS_ASSERT(out_tensor_idx < this->tensors.size());
|
||||
auto *out_tensor = this->tensors.at(out_tensor_idx);
|
||||
MS_ASSERT(out_tensor != nullptr);
|
||||
this->outputs.emplace_back(out_tensor);
|
||||
}
|
||||
}
|
||||
|
||||
void LiteSession::InitGraphInputMap(const lite::Model *model) {
|
||||
auto meta_graph = model->GetMetaGraph();
|
||||
MS_ASSERT(this->input_map.empty());
|
||||
MS_ASSERT(meta_graph != nullptr);
|
||||
|
@ -108,7 +134,12 @@ void LiteSession::InitGraphInOutTensor(const lite::Model *model) {
|
|||
this->input_map[in_node->name()->str()].emplace_back(ms_tensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void LiteSession::InitGraphOutputMap(const lite::Model *model) {
|
||||
auto meta_graph = model->GetMetaGraph();
|
||||
MS_ASSERT(this->output_map.empty());
|
||||
MS_ASSERT(meta_graph != nullptr);
|
||||
auto graph_output_node_indexes = GetGraphOutputNodes(meta_graph);
|
||||
for (auto out_node_index : graph_output_node_indexes) {
|
||||
auto *out_node = meta_graph->nodes()->GetAs<schema::CNode>(out_node_index);
|
||||
|
@ -136,6 +167,13 @@ void LiteSession::InitGraphInOutTensor(const lite::Model *model) {
|
|||
}
|
||||
}
|
||||
|
||||
void LiteSession::InitGraphInOutTensors(const lite::Model *model) {
|
||||
InitGraphInputTensors(model);
|
||||
InitGraphOutputTensors(model);
|
||||
InitGraphInputMap(model);
|
||||
InitGraphOutputMap(model);
|
||||
}
|
||||
|
||||
int LiteSession::CompileGraph(Model *model) {
|
||||
// model.MetaGraph ==> kernels
|
||||
if (model == nullptr) {
|
||||
|
@ -149,7 +187,7 @@ int LiteSession::CompileGraph(Model *model) {
|
|||
return ret;
|
||||
}
|
||||
|
||||
InitGraphInOutTensor(model);
|
||||
InitGraphInOutTensors(model);
|
||||
|
||||
// scheduler kernels
|
||||
Scheduler scheduler(context_);
|
||||
|
@ -228,15 +266,7 @@ LiteSession::~LiteSession() {
|
|||
}
|
||||
delete tensor;
|
||||
}
|
||||
// inputs outputs input_map output_map are freed in tensors
|
||||
for (auto *input : inputs) {
|
||||
((tensor::LiteTensor *)input)->SetTensorImpl(nullptr);
|
||||
delete input;
|
||||
}
|
||||
for (auto *output : outputs) {
|
||||
((tensor::LiteTensor *)output)->SetTensorImpl(nullptr);
|
||||
delete output;
|
||||
}
|
||||
// tensor::Tensor * in input_map output_map are freed in tensors
|
||||
for (auto iter : this->input_map) {
|
||||
for (auto *ms_tensor : iter.second) {
|
||||
((tensor::LiteTensor *)ms_tensor)->SetTensorImpl(nullptr);
|
||||
|
|
|
@ -56,7 +56,15 @@ class LiteSession : public session::LiteSession {
|
|||
protected:
|
||||
int ConvertTensors(const lite::Model *model);
|
||||
|
||||
void InitGraphInOutTensor(const lite::Model *model);
|
||||
void InitGraphInOutTensors(const lite::Model *model);
|
||||
|
||||
void InitGraphInputTensors(const lite::Model *model);
|
||||
|
||||
void InitGraphOutputTensors(const lite::Model *model);
|
||||
|
||||
void InitGraphInputMap(const lite::Model *model);
|
||||
|
||||
void InitGraphOutputMap(const lite::Model *model);
|
||||
|
||||
protected:
|
||||
Context *context_ = nullptr;
|
||||
|
|
|
@ -19,6 +19,8 @@
|
|||
#include <algorithm>
|
||||
#include "include/errorcode.h"
|
||||
#include "src/kernel_factory.h"
|
||||
#include "src/common/graph_util.h"
|
||||
#include "src/common/utils.h"
|
||||
#if SUPPORT_GPU
|
||||
#include "src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
|
||||
#endif
|
||||
|
@ -51,6 +53,7 @@ int Scheduler::InitOp2Kernel(const lite::Model *model, std::vector<tensor::Tenso
|
|||
auto meta_graph = model->GetMetaGraph();
|
||||
MS_EXCEPTION_IF_NULL(meta_graph);
|
||||
uint32_t kernelCount = meta_graph->nodes()->size();
|
||||
auto graph_output_node_indexes = GetGraphOutputNodes(meta_graph);
|
||||
for (uint32_t i = 0; i < kernelCount; i++) {
|
||||
auto cNode = meta_graph->nodes()->GetAs<schema::CNode>(i);
|
||||
std::vector<tensor::Tensor *> inputs;
|
||||
|
@ -93,6 +96,7 @@ int Scheduler::InitOp2Kernel(const lite::Model *model, std::vector<tensor::Tenso
|
|||
return RET_ERROR;
|
||||
}
|
||||
kernel->set_name(cNode->name()->str());
|
||||
kernel->set_is_model_output(IsContain(graph_output_node_indexes, size_t(i)));
|
||||
kernels->emplace_back(kernel);
|
||||
}
|
||||
return RET_OK;
|
||||
|
|
Loading…
Reference in New Issue