forked from mindspore-Ecosystem/mindspore
fix memory leak
This commit is contained in:
parent
858f2a5c9c
commit
759e43957e
|
@ -424,6 +424,15 @@ KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) const {
|
|||
return it->second;
|
||||
}
|
||||
|
||||
void SessionBasic::ClearGraph() {
|
||||
auto graph_iter = graphs_.begin();
|
||||
while (graph_iter != graphs_.end()) {
|
||||
graph_iter->second.reset();
|
||||
graphs_.erase(graph_iter++);
|
||||
}
|
||||
graph_sum_ = 0;
|
||||
}
|
||||
|
||||
void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr ¶meter) {
|
||||
auto graph_id = GetGraphIdByNode(out_node);
|
||||
if (graph_id == kInvalidGraphId) {
|
||||
|
|
|
@ -113,6 +113,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
const std::vector<tensor::TensorPtr> &inputs);
|
||||
// Get graph by graph id, if not exist return null ptr
|
||||
KernelGraphPtr GetGraph(GraphId graph_id) const;
|
||||
void ClearGraph();
|
||||
#ifdef ENABLE_DEBUGGER
|
||||
// set debugger
|
||||
void SetDebugger() {
|
||||
|
|
|
@ -1504,6 +1504,15 @@ void PynativeExecutor::ClearResidualRes(const std::string &cell_id) {
|
|||
}
|
||||
if (dynamic_cell_) {
|
||||
VectorClear<std::vector<TopCellInfo>>(&top_cell_list_, cell_id);
|
||||
if (IsTopGraph(cell_id) && graph_stack_.empty() && !IsBpropGraph(cell_id)) {
|
||||
// Clear previous step resource
|
||||
auto resource = GetResource(cell_id);
|
||||
if (resource != nullptr && resource->results().find(pipeline::kBackend) != resource->results().end()) {
|
||||
compile::BackendPtr backend = resource->results()[pipeline::kBackend].cast<compile::BackendPtr>();
|
||||
auto ms_backend = std::dynamic_pointer_cast<compile::MsBackend>(backend);
|
||||
ms_backend->ClearSessionGraphs();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -122,6 +122,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
|||
void EndGraph(const py::object &cell, const py::object &out, const py::args &args);
|
||||
void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args);
|
||||
|
||||
// Get info
|
||||
bool GetIsDynamicCell() const { return dynamic_cell_; }
|
||||
// Call by python
|
||||
void Clear(const std::string &flag = "");
|
||||
void Clean();
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "backend/session/session_factory.h"
|
||||
#include "pipeline/pynative/pynative_execute.h"
|
||||
#include "ir/anf.h"
|
||||
#include "pybind_api/ir/base_ref_py.h"
|
||||
#include "utils/callbacks.h"
|
||||
|
@ -101,7 +102,9 @@ LinConvertResult MsBackend::MsConvert(const GraphSegmentPtr &segment, const std:
|
|||
result.graph_id = graph_id;
|
||||
|
||||
graph_id_map_[graph_id] = result;
|
||||
(void)g_ConvertCache.emplace(segment, result);
|
||||
if (!pynative::PynativeExecutor::GetInstance()->GetIsDynamicCell()) {
|
||||
(void)g_ConvertCache.emplace(segment, result);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
|
@ -214,6 +217,11 @@ GraphId MsBackend::CompileGraph(NotNull<FuncGraphPtr> fg) { return target_sess_-
|
|||
|
||||
VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return MsRunGraph(graph_id, args); }
|
||||
|
||||
void MsBackend::ClearSessionGraphs() {
|
||||
if (target_sess_ != nullptr) {
|
||||
target_sess_->ClearGraph();
|
||||
}
|
||||
}
|
||||
#ifdef ENABLE_DEBUGGER
|
||||
void MsBackend::SetDebugger() { target_sess_->SetDebugger(); }
|
||||
#endif
|
||||
|
|
|
@ -71,6 +71,7 @@ class MsBackend : public Backend {
|
|||
void Link(GraphId) override;
|
||||
GraphId CompileGraph(NotNull<FuncGraphPtr> fg) override;
|
||||
VectorRef RunGraph(GraphId graph_id, const VectorRef &args);
|
||||
void ClearSessionGraphs();
|
||||
void CreateOtherSession(const std::string &target);
|
||||
|
||||
#ifdef ENABLE_DEBUGGER
|
||||
|
|
Loading…
Reference in New Issue