diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index dd261a582db..e33cc87bd30 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -424,6 +424,15 @@ KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) const { return it->second; } +void SessionBasic::ClearGraph() { + auto graph_iter = graphs_.begin(); + while (graph_iter != graphs_.end()) { + graph_iter->second.reset(); + graphs_.erase(graph_iter++); + } + graph_sum_ = 0; +} + void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr ¶meter) { auto graph_id = GetGraphIdByNode(out_node); if (graph_id == kInvalidGraphId) { diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h index d3b986531b8..531cc535d19 100644 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -113,6 +113,7 @@ class SessionBasic : public std::enable_shared_from_this { const std::vector &inputs); // Get graph by graph id, if not exist return null ptr KernelGraphPtr GetGraph(GraphId graph_id) const; + void ClearGraph(); #ifdef ENABLE_DEBUGGER // set debugger void SetDebugger() { diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 0f539cc6e9c..91280937057 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -1504,6 +1504,15 @@ void PynativeExecutor::ClearResidualRes(const std::string &cell_id) { } if (dynamic_cell_) { VectorClear>(&top_cell_list_, cell_id); + if (IsTopGraph(cell_id) && graph_stack_.empty() && !IsBpropGraph(cell_id)) { + // Clear previous step resource + auto resource = GetResource(cell_id); + if (resource != nullptr && resource->results().find(pipeline::kBackend) != resource->results().end()) { + compile::BackendPtr backend = resource->results()[pipeline::kBackend].cast(); + auto ms_backend = std::dynamic_pointer_cast(backend); + ms_backend->ClearSessionGraphs(); + } + } } } diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h index b74a8510960..b2a3dfbf51f 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h @@ -122,6 +122,8 @@ class PynativeExecutor : public std::enable_shared_from_this { void EndGraph(const py::object &cell, const py::object &out, const py::args &args); void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args); + // Get info + bool GetIsDynamicCell() const { return dynamic_cell_; } // Call by python void Clear(const std::string &flag = ""); void Clean(); diff --git a/mindspore/ccsrc/vm/backend.cc b/mindspore/ccsrc/vm/backend.cc index 0265fe3e642..432223222f7 100644 --- a/mindspore/ccsrc/vm/backend.cc +++ b/mindspore/ccsrc/vm/backend.cc @@ -19,6 +19,7 @@ #include #include "backend/session/session_factory.h" +#include "pipeline/pynative/pynative_execute.h" #include "ir/anf.h" #include "pybind_api/ir/base_ref_py.h" #include "utils/callbacks.h" @@ -101,7 +102,9 @@ LinConvertResult MsBackend::MsConvert(const GraphSegmentPtr &segment, const std: result.graph_id = graph_id; graph_id_map_[graph_id] = result; - (void)g_ConvertCache.emplace(segment, result); + if (!pynative::PynativeExecutor::GetInstance()->GetIsDynamicCell()) { + (void)g_ConvertCache.emplace(segment, result); + } return result; } @@ -214,6 +217,11 @@ GraphId MsBackend::CompileGraph(NotNull fg) { return target_sess_- VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return MsRunGraph(graph_id, args); } +void MsBackend::ClearSessionGraphs() { + if (target_sess_ != nullptr) { + target_sess_->ClearGraph(); + } +} #ifdef ENABLE_DEBUGGER void MsBackend::SetDebugger() { target_sess_->SetDebugger(); } #endif diff --git a/mindspore/ccsrc/vm/backend.h b/mindspore/ccsrc/vm/backend.h index c7f1208057f..45f1d7f996c 100644 --- a/mindspore/ccsrc/vm/backend.h +++ b/mindspore/ccsrc/vm/backend.h @@ -71,6 +71,7 @@ class MsBackend : public Backend { void Link(GraphId) override; GraphId CompileGraph(NotNull fg) override; VectorRef RunGraph(GraphId graph_id, const VectorRef &args); + void ClearSessionGraphs(); void CreateOtherSession(const std::string &target); #ifdef ENABLE_DEBUGGER