forked from mindspore-Ecosystem/mindspore
handle mem leak in pynative mode
use enter and leave construct
This commit is contained in:
parent
33faba7100
commit
1f20e552ed
|
@ -304,7 +304,7 @@ void ExecutorPy::DelNetRes(const std::string &id) {
|
|||
|
||||
void ExecutorPy::ClearRes() {
|
||||
MS_LOG(INFO) << "Clean executor resource!";
|
||||
Resource::ClearPrimitivePyPythonObj();
|
||||
Resource::mem_cleaner().ClearPrimitivePyPythonObj();
|
||||
executor_ = nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -275,39 +275,78 @@ Any Resource::GetAttrPtr(const TypeId &type, const std::string &name) {
|
|||
return GetMethodOrAttr(name, type_id, attr_map);
|
||||
}
|
||||
|
||||
std::unordered_map<PrimitivePy *, bool> Resource::py_objs_ = std::unordered_map<PrimitivePy *, bool>();
|
||||
void Resource::RecordPrimitivePy(PrimitivePy *prim) {
|
||||
MemoryCleaner Resource::mem_cleaner_ = MemoryCleaner();
|
||||
void MemoryCleaner::RecordPrimitivePy(PrimitivePy *prim) {
|
||||
if (prim == nullptr) {
|
||||
return;
|
||||
}
|
||||
py_objs_[prim] = true;
|
||||
all_primitives_[prim] = true;
|
||||
}
|
||||
|
||||
void Resource::ErasePrimitivePy(PrimitivePy *prim) {
|
||||
void MemoryCleaner::ErasePrimitivePy(PrimitivePy *prim) {
|
||||
if (prim == nullptr) {
|
||||
return;
|
||||
}
|
||||
auto it = py_objs_.find(prim);
|
||||
if (it == py_objs_.end()) {
|
||||
auto it = all_primitives_.find(prim);
|
||||
if (it == all_primitives_.end()) {
|
||||
return;
|
||||
}
|
||||
// If the flag is false, the pointer has already been released, so it must not be accessed.
|
||||
if (!it->second) {
|
||||
return;
|
||||
}
|
||||
py_objs_[prim] = false;
|
||||
all_primitives_[prim] = false;
|
||||
prim->SetPyObj(py::none());
|
||||
}
|
||||
|
||||
void Resource::ClearPrimitivePyPythonObj() {
|
||||
for (auto &it : py_objs_) {
|
||||
void MemoryCleaner::ClearPrimitivePyPythonObj() {
|
||||
for (auto &it : all_primitives_) {
|
||||
if (it.second) {
|
||||
it.first->SetPyObj(py::none());
|
||||
}
|
||||
}
|
||||
py_objs_.clear();
|
||||
all_primitives_.clear();
|
||||
}
|
||||
|
||||
// Adds `prim` to the set of pynative short-life primitives.
// A null pointer, or a primitive that is already in the set, is ignored without logging.
void MemoryCleaner::RecordPynativeShortLifePrimitivePy(PrimitivePy *prim) {
  if (prim == nullptr || pynative_short_life_primitives_.count(prim) != 0) {
    return;
  }
  MS_LOG(DEBUG) << "Record pynative tmp primitve:" << prim->ToString();
  (void)pynative_short_life_primitives_.insert(prim);
}
|
||||
|
||||
// Resets the Python object of a primitive that was recorded as a pynative short-life primitive.
//
// Nothing happens when `prim` is null, when it was never recorded as short-life, or when it is no longer marked
// as live in all_primitives_. The entry itself stays in pynative_short_life_primitives_; that set is only emptied
// by ClearPynativeShortLifePrimitivePy().
void MemoryCleaner::ErasePynativeShortLifePrimitivePy(PrimitivePy *prim) {
  if (prim == nullptr) {
    return;
  }
  if (pynative_short_life_primitives_.find(prim) == pynative_short_life_primitives_.end()) {
    return;
  }
  // Entries of the short-life set are not removed when a primitive is destroyed; only the flag in all_primitives_
  // is set to false by ErasePrimitivePy(). Check that flag before dereferencing `prim` for the log message,
  // otherwise a primitive released before the set is cleared would be accessed after its destruction.
  const auto live_it = all_primitives_.find(prim);
  if (live_it == all_primitives_.end() || !live_it->second) {
    return;
  }
  MS_LOG(DEBUG) << "Erase pynative tmp primitive:" << prim->ToString();
  ErasePrimitivePy(prim);
}
|
||||
|
||||
void MemoryCleaner::ClearPynativeShortLifePrimitivePy() {
|
||||
for (auto &primitive : pynative_short_life_primitives_) {
|
||||
ErasePynativeShortLifePrimitivePy(primitive);
|
||||
}
|
||||
pynative_short_life_primitives_.clear();
|
||||
}
|
||||
|
||||
// Sets the flag that marks the pynative construct process as running.
void MemoryCleaner::EnterPynativeConstructProcess() {
  pynative_in_construct_process_ = true;
}
|
||||
void MemoryCleaner::LeavePynativeConstructProcess() {
|
||||
pynative_in_construct_process_ = false;
|
||||
ClearPynativeShortLifePrimitivePy();
|
||||
}
|
||||
// Returns true between EnterPynativeConstructProcess() and LeavePynativeConstructProcess().
bool MemoryCleaner::IsInPynativeConstructProcess() const {
  return pynative_in_construct_process_;
}
|
||||
// Sets the flag that marks the pynative end-graph process as running.
void MemoryCleaner::EnterPynativeEndGraphProcess() {
  pynative_in_end_graph_process_ = true;
}
|
||||
// Clears the flag that marks the pynative end-graph process as running.
void MemoryCleaner::LeavePynativeEndGraphProcess() {
  pynative_in_end_graph_process_ = false;
}
|
||||
// Returns true between EnterPynativeEndGraphProcess() and LeavePynativeEndGraphProcess().
bool MemoryCleaner::IsInPynativeEndGraphProcess() const {
  return pynative_in_end_graph_process_;
}
|
||||
|
||||
void Resource::Clean() {
|
||||
// AbstractTensor->elements() will be saved in AbstractBasePtrList
|
||||
args_spec_.clear();
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
|
@ -52,6 +53,34 @@ BuiltInTypeMap &GetMethodMap();
|
|||
|
||||
BuiltInTypeMap &GetAttrMap();
|
||||
|
||||
class MemoryCleaner {
|
||||
public:
|
||||
MemoryCleaner() = default;
|
||||
~MemoryCleaner() = default;
|
||||
void RecordPrimitivePy(PrimitivePy *prim);
|
||||
void ErasePrimitivePy(PrimitivePy *prim);
|
||||
void ClearPrimitivePyPythonObj();
|
||||
|
||||
void RecordPynativeShortLifePrimitivePy(PrimitivePy *prim);
|
||||
void ErasePynativeShortLifePrimitivePy(PrimitivePy *prim);
|
||||
void ClearPynativeShortLifePrimitivePy();
|
||||
|
||||
void EnterPynativeConstructProcess();
|
||||
void LeavePynativeConstructProcess();
|
||||
bool IsInPynativeConstructProcess() const;
|
||||
void EnterPynativeEndGraphProcess();
|
||||
void LeavePynativeEndGraphProcess();
|
||||
bool IsInPynativeEndGraphProcess() const;
|
||||
|
||||
private:
|
||||
std::unordered_map<PrimitivePy *, bool> all_primitives_;
|
||||
// PrimitivePy objects that created in pynative construct process.These primitives should be released after construct
|
||||
// finished.
|
||||
std::unordered_set<PrimitivePy *> pynative_short_life_primitives_;
|
||||
bool pynative_in_construct_process_{false};
|
||||
bool pynative_in_end_graph_process_{false};
|
||||
};
|
||||
|
||||
class Resource : public ResourceBase {
|
||||
public:
|
||||
explicit Resource(const py::object &obj = py::none());
|
||||
|
@ -80,13 +109,11 @@ class Resource : public ResourceBase {
|
|||
}
|
||||
bool gpu_loopsink_flag() { return gpu_loopsink_flag_; }
|
||||
int64_t gpu_loopsink_size() { return gpu_loopsink_size_; }
|
||||
static void RecordPrimitivePy(PrimitivePy *prim);
|
||||
static void ErasePrimitivePy(PrimitivePy *prim);
|
||||
static void ClearPrimitivePyPythonObj();
|
||||
// Reclaim resource and clear the cache.
|
||||
// ExecutorPy::Compile() can be called multiple times, so cache
|
||||
// should be cleared.
|
||||
void Clean();
|
||||
static MemoryCleaner &mem_cleaner() { return mem_cleaner_; }
|
||||
|
||||
private:
|
||||
abstract::AnalysisEnginePtr engine_;
|
||||
|
@ -96,7 +123,8 @@ class Resource : public ResourceBase {
|
|||
bool is_cleaned_;
|
||||
bool gpu_loopsink_flag_{false};
|
||||
int64_t gpu_loopsink_size_{1};
|
||||
static std::unordered_map<PrimitivePy *, bool> py_objs_;
|
||||
// Used to handle mem leak objects.
|
||||
static MemoryCleaner mem_cleaner_;
|
||||
};
|
||||
|
||||
using ResourcePtr = std::shared_ptr<pipeline::Resource>;
|
||||
|
|
|
@ -2476,7 +2476,12 @@ void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) {
|
|||
}
|
||||
|
||||
// Runs EndGraphInner through PynativeExecutorTry while the memory cleaner is flagged as being in the pynative
// end-graph process.
//
// cell, out, args: forwarded unchanged to PynativeExecutor::EndGraphInner.
void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) {
  MS_LOG(DEBUG) << "Enter end graph process.";
  {
    // Scope guard: the end-graph flag is reset on every exit path. Without it, an exception leaving
    // PynativeExecutorTry would leave the flag set to true for the rest of the run.
    // NOTE(review): PynativeExecutorTry is not visible here; it is assumed that it may propagate exceptions.
    struct EndGraphProcessGuard {
      EndGraphProcessGuard() { pipeline::Resource::mem_cleaner().EnterPynativeEndGraphProcess(); }
      ~EndGraphProcessGuard() { pipeline::Resource::mem_cleaner().LeavePynativeEndGraphProcess(); }
      EndGraphProcessGuard(const EndGraphProcessGuard &) = delete;
      EndGraphProcessGuard &operator=(const EndGraphProcessGuard &) = delete;
    };
    EndGraphProcessGuard end_graph_guard;
    PynativeExecutorTry(this, &PynativeExecutor::EndGraphInner, cell, out, args);
  }
  MS_LOG(DEBUG) << "Leave end graph process.";
}
|
||||
|
||||
void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
|
||||
|
@ -2491,6 +2496,24 @@ void PynativeExecutor::Sync() {
|
|||
session->SyncStream();
|
||||
}
|
||||
|
||||
// Records `cell` as the top cell and switches the memory cleaner into the pynative construct process.
// If a top cell is already recorded, the call has no effect.
void PynativeExecutor::EnterConstruct(const py::object &cell) {
  const bool top_cell_unset = (top_cell_ == nullptr);
  if (top_cell_unset) {
    top_cell_ = cell.ptr();
    pipeline::Resource::mem_cleaner().EnterPynativeConstructProcess();
    MS_LOG(DEBUG) << "Enter construct process.";
  }
}
|
||||
|
||||
// Forgets the top cell and switches the memory cleaner out of the pynative construct process.
// The call has no effect unless `cell` is the object currently recorded as the top cell.
void PynativeExecutor::LeaveConstruct(const py::object &cell) {
  const bool is_top_cell = (top_cell_ == cell.ptr());
  if (is_top_cell) {
    top_cell_ = nullptr;
    pipeline::Resource::mem_cleaner().LeavePynativeConstructProcess();
    MS_LOG(DEBUG) << "Leave construct process.";
  }
}
|
||||
|
||||
REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
|
||||
(void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_")
|
||||
.def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.")
|
||||
|
@ -2502,6 +2525,10 @@ REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
|
|||
.def("sync", &PynativeExecutor::Sync, "pynative sync stream.")
|
||||
.def("__call__", &PynativeExecutor::Run, "pynative executor run grad graph.")
|
||||
.def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false),
|
||||
"Executor set grad flag.");
|
||||
"Executor set grad flag.")
|
||||
.def("enter_construct", &PynativeExecutor::EnterConstruct,
|
||||
"Do something before enter construct function.")
|
||||
.def("leave_construct", &PynativeExecutor::LeaveConstruct,
|
||||
"Do something after leave construct function.");
|
||||
}));
|
||||
} // namespace mindspore::pynative
|
||||
|
|
|
@ -108,6 +108,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
|||
bool need_replace_forward() const { return need_replace_forward_; }
|
||||
bool grad_flag() const { return grad_flag_; }
|
||||
void set_grad_flag(bool flag) { grad_flag_ = flag; }
|
||||
void EnterConstruct(const py::object &cell);
|
||||
void LeaveConstruct(const py::object &cell);
|
||||
|
||||
py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info);
|
||||
OpExecInfoPtr GenerateOpExecInfo(const py::args &args);
|
||||
|
@ -263,6 +265,12 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
|||
bool dynamic_cell_{false};
|
||||
bool grad_is_running_{false};
|
||||
bool need_replace_forward_{true};
|
||||
// Pointer to the top Python Cell object, i.e. the network (a class inheriting from Cell) run by the Python test
|
||||
// script, such as Resnet50(Cell) or LeNet(Cell). It is used to distinguish temporary primitives from global
|
||||
// primitives when controlling memory release. Global primitives are always created in the top cell's '__init__'
|
||||
// function, while temporary primitives are created elsewhere. Temporary primitives are released after the top
|
||||
// cell's 'construct' function has executed; global primitives are not.
|
||||
PyObject *top_cell_{nullptr};
|
||||
|
||||
// Used for construct grad graph
|
||||
FuncGraphPtr curr_g_{nullptr};
|
||||
|
|
|
@ -54,9 +54,18 @@ std::map<std::string, py::object> PrimitivePy::hook_grad_;
|
|||
|
||||
PrimitivePy::PrimitivePy(const py::str &name, const py::object &python_obj)
|
||||
: Primitive(name, false), python_obj_(python_obj), signatures_() {
|
||||
pipeline::Resource::RecordPrimitivePy(this);
|
||||
auto &mem_cleaner = pipeline::Resource::mem_cleaner();
|
||||
mem_cleaner.RecordPrimitivePy(this);
|
||||
if (mem_cleaner.IsInPynativeConstructProcess() && !mem_cleaner.IsInPynativeEndGraphProcess()) {
|
||||
mem_cleaner.RecordPynativeShortLifePrimitivePy(this);
|
||||
}
|
||||
}
|
||||
PrimitivePy::~PrimitivePy() {
|
||||
// Erase the primitive here so that its released flag is set to false; this avoids accessing a released pointer
|
||||
// when the primitives are cleared in resource.
|
||||
pipeline::Resource::mem_cleaner().ErasePrimitivePy(this);
|
||||
MS_LOG(DEBUG) << "Release:" << ToString();
|
||||
}
|
||||
PrimitivePy::~PrimitivePy() { pipeline::Resource::ErasePrimitivePy(this); }
|
||||
void PrimitivePy::SetPyObj(const py::object &obj) { python_obj_ = obj; }
|
||||
void PrimitivePy::set_signatures(const std::vector<Signature> &signatures) {
|
||||
signatures_ = signatures;
|
||||
|
|
|
@ -321,6 +321,12 @@ class _PynativeExecutor:
|
|||
def set_grad_flag(self, flag):
|
||||
self._executor.set_grad_flag(flag)
|
||||
|
||||
def enter_construct(self, cell):
|
||||
self._executor.enter_construct(cell)
|
||||
|
||||
def leave_construct(self, cell):
|
||||
self._executor.leave_construct(cell)
|
||||
|
||||
def __call__(self, obj, *args, **kwargs):
|
||||
args = args + tuple(kwargs.values())
|
||||
return self._executor(obj, args, "")
|
||||
|
|
|
@ -352,9 +352,13 @@ class Cell(Cell_):
|
|||
if not cast_inputs:
|
||||
cast_inputs = inputs
|
||||
if self.enable_hook:
|
||||
_pynative_exec.enter_construct(self)
|
||||
output = self._hook_construct(*cast_inputs, **kwargs)
|
||||
_pynative_exec.leave_construct(self)
|
||||
else:
|
||||
_pynative_exec.enter_construct(self)
|
||||
output = self.construct(*cast_inputs, **kwargs)
|
||||
_pynative_exec.leave_construct(self)
|
||||
if isinstance(output, Parameter):
|
||||
output = output.data
|
||||
if self.requires_grad is True:
|
||||
|
|
|
@ -19,7 +19,6 @@ import pytest
|
|||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.ops.operations import _grad_ops as G
|
||||
from mindspore.ops.composite import GradOperation
|
||||
|
||||
|
@ -29,7 +28,6 @@ class NetSigmoidGrad(nn.Cell):
|
|||
super(NetSigmoidGrad, self).__init__()
|
||||
self.sigmoid_grad = G.SigmoidGrad()
|
||||
|
||||
@ms_function
|
||||
def construct(self, y, dy):
|
||||
return self.sigmoid_grad(y, dy)
|
||||
|
||||
|
@ -40,7 +38,6 @@ class Grad(nn.Cell):
|
|||
self.grad = GradOperation(get_all=True, sens_param=True)
|
||||
self.network = network
|
||||
|
||||
@ms_function
|
||||
def construct(self, y, y_grad, dout):
|
||||
return self.grad(self.network)(y, y_grad, dout)
|
||||
|
||||
|
|
Loading…
Reference in New Issue