fix memory leak

This commit is contained in:
simson 2020-12-21 15:42:29 +08:00
parent 858f2a5c9c
commit 759e43957e
6 changed files with 31 additions and 1 deletions

View File

@ -424,6 +424,15 @@ KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) const {
return it->second;
}
void SessionBasic::ClearGraph() {
auto graph_iter = graphs_.begin();
while (graph_iter != graphs_.end()) {
graph_iter->second.reset();
graphs_.erase(graph_iter++);
}
graph_sum_ = 0;
}
void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr &parameter) {
auto graph_id = GetGraphIdByNode(out_node);
if (graph_id == kInvalidGraphId) {

View File

@ -113,6 +113,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
const std::vector<tensor::TensorPtr> &inputs);
// Get graph by graph id, if not exist return null ptr
KernelGraphPtr GetGraph(GraphId graph_id) const;
void ClearGraph();
#ifdef ENABLE_DEBUGGER
// set debugger
void SetDebugger() {

View File

@ -1504,6 +1504,15 @@ void PynativeExecutor::ClearResidualRes(const std::string &cell_id) {
}
if (dynamic_cell_) {
VectorClear<std::vector<TopCellInfo>>(&top_cell_list_, cell_id);
if (IsTopGraph(cell_id) && graph_stack_.empty() && !IsBpropGraph(cell_id)) {
// Clear previous step resource
auto resource = GetResource(cell_id);
if (resource != nullptr && resource->results().find(pipeline::kBackend) != resource->results().end()) {
compile::BackendPtr backend = resource->results()[pipeline::kBackend].cast<compile::BackendPtr>();
auto ms_backend = std::dynamic_pointer_cast<compile::MsBackend>(backend);
ms_backend->ClearSessionGraphs();
}
}
}
}

View File

@ -122,6 +122,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void EndGraph(const py::object &cell, const py::object &out, const py::args &args);
void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args);
// Get info
bool GetIsDynamicCell() const { return dynamic_cell_; }
// Call by python
void Clear(const std::string &flag = "");
void Clean();

View File

@ -19,6 +19,7 @@
#include <vector>
#include "backend/session/session_factory.h"
#include "pipeline/pynative/pynative_execute.h"
#include "ir/anf.h"
#include "pybind_api/ir/base_ref_py.h"
#include "utils/callbacks.h"
@ -101,7 +102,9 @@ LinConvertResult MsBackend::MsConvert(const GraphSegmentPtr &segment, const std:
result.graph_id = graph_id;
graph_id_map_[graph_id] = result;
(void)g_ConvertCache.emplace(segment, result);
if (!pynative::PynativeExecutor::GetInstance()->GetIsDynamicCell()) {
(void)g_ConvertCache.emplace(segment, result);
}
return result;
}
@ -214,6 +217,11 @@ GraphId MsBackend::CompileGraph(NotNull<FuncGraphPtr> fg) { return target_sess_-
VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return MsRunGraph(graph_id, args); }
void MsBackend::ClearSessionGraphs() {
if (target_sess_ != nullptr) {
target_sess_->ClearGraph();
}
}
#ifdef ENABLE_DEBUGGER
void MsBackend::SetDebugger() { target_sess_->SetDebugger(); }
#endif

View File

@ -71,6 +71,7 @@ class MsBackend : public Backend {
void Link(GraphId) override;
GraphId CompileGraph(NotNull<FuncGraphPtr> fg) override;
VectorRef RunGraph(GraphId graph_id, const VectorRef &args);
void ClearSessionGraphs();
void CreateOtherSession(const std::string &target);
#ifdef ENABLE_DEBUGGER