!20787 Memory leak fix.

Merge pull request !20787 from zhangzhaoju/master_leak_fix
i-robot 2021-08-11 11:46:22 +00:00 committed by Gitee
commit 50d54a7482
10 changed files with 46 additions and 30 deletions
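
For reference, the user-facing side of this fix is a new enable_grad_cache context option (default True); every grad-graph cache insertion in the diff below becomes conditional on it. A minimal usage sketch, assuming a standard MindSpore install where mindspore.context is available:

from mindspore import context

# Keeping the default (True) preserves the old caching behaviour. Disabling it
# trades graph reuse for bounded memory when input shapes change between steps,
# since each distinct shape would otherwise add another cached bprop graph.
context.set_context(enable_grad_cache=False)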

View File

@@ -25,7 +25,7 @@
#include <vector>
#include <algorithm>
#include "ir/anf.h"
#include "pipeline/jit/prim_bprop_optimizer.h"
#include "frontend/optimizer/ad/prim_bprop_optimizer.h"
#include "frontend/optimizer/ad/adjoint.h"
#include "frontend/optimizer/ad/dfunctor.h"
#include "frontend/optimizer/ad/kpynative.h"
@@ -90,8 +90,11 @@ FuncGraphPtr GetZerosLike(const abstract::AbstractBasePtrList &args_spec) {
MS_EXCEPTION_IF_NULL(specialized_zeros_like_fg);
auto opted_zeros_like_fg = ZerosLikePrimOptPass(resource);
MS_EXCEPTION_IF_NULL(opted_zeros_like_fg);
-zeros_like_funcgraph_cache[args_spec] = opted_zeros_like_fg;
-return BasicClone(opted_zeros_like_fg);
+auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_GRAD_CACHE);
+if (enable_grad_cache) {
+zeros_like_funcgraph_cache[args_spec] = BasicClone(opted_zeros_like_fg);
+}
+return opted_zeros_like_fg;
}
FuncGraphPtr GetHyperAdd(const abstract::AbstractBasePtrList &args_spec) {
@@ -146,8 +149,11 @@ FuncGraphPtr GetOnesLike(const abstract::AbstractBasePtrList &args_spec) {
pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
auto specialized_ones_like_fg = pipeline::Renormalize(resource, ones_like_fg, args_spec);
MS_EXCEPTION_IF_NULL(specialized_ones_like_fg);
-ones_like_funcgraph_cache[args_spec] = specialized_ones_like_fg;
-return BasicClone(specialized_ones_like_fg);
+auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_GRAD_CACHE);
+if (enable_grad_cache) {
+ones_like_funcgraph_cache[args_spec] = BasicClone(specialized_ones_like_fg);
+}
+return specialized_ones_like_fg;
}
AnfNodePtr BuildOnesLikeValue(const FuncGraphPtr &tape, const ValuePtr &out) {
@@ -359,8 +365,8 @@ FuncGraphPtr KPynativeCellImpl::Finish(const AnfNodePtrList &weights, bool grad_
SetOutput(weights, grad_inputs, grad_weights);
// Replace Parameter of primal funcgraph with parameter of tape_;
ReplacePrimalParameter(weights, has_sens_arg);
-if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
+auto save_graphs_flg = MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
+if (save_graphs_flg) {
DumpIR("before_final_opt.ir", tape_);
}
return tape_;
@@ -645,7 +651,7 @@ bool KPynativeCellImpl::BuildAdjoint(const CNodePtr &cnode, const ValuePtrList &
FuncGraphPtr OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &cnode, const ValuePtrList &op_args,
const ValuePtr &out) {
auto optimized_bprop_fg =
-pipeline::PrimBpropOptimizer::GetPrimBpropOptimizerInst().OptimizeBPropFuncGraph(bprop_fg, cnode, op_args, out);
+PrimBpropOptimizer::GetPrimBpropOptimizerInst().OptimizeBPropFuncGraph(bprop_fg, cnode, op_args, out);
return optimized_bprop_fg;
}
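
The hunks above all apply the same policy: build the optimized graph, insert it into the cache only when enable_grad_cache is set, and otherwise hand the freshly built graph back to the caller. A minimal Python sketch of that policy, with hypothetical names (_graph_cache, build_graph) standing in for the FuncGraph caches keyed by args_spec (the real code uses Renormalize and BasicClone on FuncGraphPtr objects):

import copy

_graph_cache = {}  # hypothetical stand-in for zeros_like_funcgraph_cache and friends

def get_graph(args_spec, build_graph, enable_grad_cache):
    cached = _graph_cache.get(args_spec)
    if cached is not None:
        # Cache hit: return a clone so callers cannot mutate the cached graph.
        return copy.deepcopy(cached)
    graph = build_graph(args_spec)
    if enable_grad_cache:
        # Populate the cache only when caching is enabled. With ever-changing
        # input shapes every args_spec is new, so an unconditional insert lets
        # the cache grow without bound; that growth is the leak being fixed.
        _graph_cache[args_spec] = copy.deepcopy(graph)
    return graph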

View File

@@ -16,11 +16,11 @@
#include <memory>
#include "ir/func_graph_cloner.h"
#include "pipeline/jit/prim_bprop_optimizer.h"
#include "frontend/optimizer/ad/prim_bprop_optimizer.h"
#include "pipeline/jit/pass.h"
namespace mindspore {
-namespace pipeline {
+namespace ad {
void PrimBpropOptGraphLevel2Info::TryFreeArgsValue(const ValuePtrList &op_args, const ValuePtr &out) {
// args_value_using_info_ contains out
if (args_value_using_info_.size() != op_args.size() + 1) {
@@ -231,9 +231,13 @@ FuncGraphPtr PrimBpropOptimizer::GetOptBpropFromCache(const FuncGraphPtr &bprop_
// do step2 opt
auto new_abs_list = AddOutToAbsList(out, abs_list);
level_2_graph_info = PrimBpropOptStep2(level_1_graph, new_abs_list);
-level_1_graph_info->graph_level_2_cache_[abs_list] = level_2_graph_info;
level_2_graph_info->TryFreeArgsValue(op_args, out);
-return BasicClone(level_2_graph_info->opt_func_graph());
+auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_GRAD_CACHE);
+if (enable_grad_cache) {
+level_1_graph_info->graph_level_2_cache_[abs_list] = level_2_graph_info;
+return BasicClone(level_2_graph_info->opt_func_graph());
+}
+return level_2_graph_info->opt_func_graph();
}
FuncGraphPtr PrimBpropOptimizer::GenSpecOptBprop(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args,
@@ -256,8 +260,8 @@ FuncGraphPtr PrimBpropOptimizer::GenSpecOptBprop(const FuncGraphPtr &bprop_fg, c
auto new_abs_list = AddOutToAbsList(out, abs_list);
auto level_2_graph_info = PrimBpropOptStep2(level_1_graph_info->opt_func_graph_, new_abs_list);
level_2_graph_info->TryFreeArgsValue(op_args, out);
-if (!hook_flg) {
+auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_GRAD_CACHE);
+if (!hook_flg && enable_grad_cache) {
tuple_list_bprop_cache_[std::pair(prim, abs_list)] = BasicClone(level_2_graph_info->opt_func_graph());
}
return level_2_graph_info->opt_func_graph();
@@ -303,7 +307,7 @@ PrimBpropOptGraphLevel2InfoPtr PrimBpropOptimizer::PrimBpropOptStep2(
return level_2_graph_info;
}
-FuncGraphPtr PrimBpropOptimizer::BpropGraphFinalOpt(const ResourcePtr &res) const {
+FuncGraphPtr PrimBpropOptimizer::BpropGraphFinalOpt(const pipeline::ResourcePtr &res) const {
MS_EXCEPTION_IF_NULL(res);
auto after_opt_bg = BpropGraphFinalOptPass(res);
return after_opt_bg;
@@ -368,5 +372,5 @@ abstract::AbstractBasePtrList PrimBpropOptimizer::AddOutToAbsList(const ValuePtr
(void)new_abs_list.emplace_back(out_abs);
return new_abs_list;
}
-} // namespace pipeline
+} // namespace ad
} // namespace mindspore

View File

@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PRIM_BPROP_OPTIMIZER_H
-#define MINDSPORE_CCSRC_PIPELINE_JIT_PRIM_BPROP_OPTIMIZER_H
+#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_PRIM_BPROP_OPTIMIZER_H
+#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_PRIM_BPROP_OPTIMIZER_H
#include <vector>
#include <utility>
@@ -27,7 +27,7 @@
#include "pipeline/jit/resource.h"
namespace mindspore {
-namespace pipeline {
+namespace ad {
struct PrimBpropOptGraphInfo;
class PrimBpropOptGraphLevel2Info;
@@ -144,7 +144,7 @@ class PrimBpropOptimizer {
const ValuePtr &out);
// do inline opt for final bprop graph
-FuncGraphPtr BpropGraphFinalOpt(const ResourcePtr &res) const;
+FuncGraphPtr BpropGraphFinalOpt(const pipeline::ResourcePtr &res) const;
private:
PrimBpropOptimizer() = default;
@@ -179,8 +179,7 @@ class PrimBpropOptimizer {
PrimBpropCache prim_bprop_cache_;
PrimTupleListCache tuple_list_bprop_cache_;
};
-} // namespace pipeline
+} // namespace ad
} // namespace mindspore
-#endif // MINDSPORE_CCSRC_PIPELINE_JIT_PRIM_BPROP_OPTIMIZER_H
+#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_PRIM_BPROP_OPTIMIZER_H

View File

@@ -8,7 +8,6 @@ file(GLOB_RECURSE _PIPELINE_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"pipeline_split.cc"
"parse/*.cc"
"static_analysis/*.cc"
"prim_bprop_optimizer.cc"
)

View File

@@ -51,7 +51,7 @@
#include "utils/shape_utils.h"
#include "utils/info.h"
#include "load_mindir/load_model.h"
#include "pipeline/jit/prim_bprop_optimizer.h"
#include "frontend/optimizer/ad/prim_bprop_optimizer.h"
#include "runtime/hardware/device_context_manager.h"
#include "utils/crypto.h"
@@ -1320,7 +1320,7 @@ void ClearResAtexit() {
device::DeviceContextManager::GetInstance().ClearDeviceContexts();
ad::g_k_prims.clear();
ad::ClearKPynativeCellStaticRes();
-PrimBpropOptimizer::GetPrimBpropOptimizerInst().Clear();
+ad::PrimBpropOptimizer::GetPrimBpropOptimizerInst().Clear();
abstract::ClearPrimEvaluatorMap();
pipeline::GetMethodMap().clear();

View File

@@ -59,7 +59,7 @@
#include "pipeline/jit/resource.h"
#include "pipeline/jit/pass.h"
#include "frontend/parallel/context.h"
#include "pipeline/jit/prim_bprop_optimizer.h"
#include "frontend/optimizer/ad/prim_bprop_optimizer.h"
#ifdef ENABLE_GE
#include "pipeline/pynative/pynative_execute_ge.h"
@@ -2447,7 +2447,7 @@ FuncGraphPtr GradExecutor::GetBpropGraph(const prim::GradOperationPtr &grad, con
auto manager = resource->manager();
MS_EXCEPTION_IF_NULL(manager);
manager->AddFuncGraph(bprop_graph);
-auto optimized_bg = pipeline::PrimBpropOptimizer::GetPrimBpropOptimizerInst().BpropGraphFinalOpt(resource);
+auto optimized_bg = ad::PrimBpropOptimizer::GetPrimBpropOptimizerInst().BpropGraphFinalOpt(resource);
if (cell_stack_.empty()) {
need_renormalize_ = false;

View File

@@ -100,7 +100,8 @@ REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) {
.value("graph_kernel_flags", MsCtxParam::MS_CTX_GRAPH_KERNEL_FLAGS)
.value("grad_for_scalar", MsCtxParam::MS_CTX_GRAD_FOR_SCALAR)
.value("save_compile_cache", MsCtxParam::MS_CTX_SAVE_COMPILE_CACHE)
.value("load_compile_cache", MsCtxParam::MS_CTX_LOAD_COMPILE_CACHE);
.value("load_compile_cache", MsCtxParam::MS_CTX_LOAD_COMPILE_CACHE)
.value("enable_grad_cache", MsCtxParam::MS_CTX_ENABLE_GRAD_CACHE);
(void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(*m, "MSContext")
.def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.")
.def("get_param", &mindspore::MsCtxGetParameter, "Get value of specified parameter.")

View File

@@ -514,7 +514,7 @@ def _check_target_specific_cfgs(device, arg_key):
enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str,
enable_sparse=bool, max_call_depth=int, env_config_path=str, graph_kernel_flags=str,
-save_compile_cache=bool, load_compile_cache=bool, grad_for_scalar=bool)
+save_compile_cache=bool, load_compile_cache=bool, grad_for_scalar=bool, enable_grad_cache=bool)
def set_context(**kwargs):
"""
Set context for running environment.
@@ -552,6 +552,7 @@
grad_for_scalar
save_compile_cache
load_compile_cache
+enable_grad_cache
=========================== =========================== =================
Args:
@@ -663,6 +664,9 @@
you should make sure the network has not been changed since the last execution. By now, we have
not support automatically checking the changes yet. Default: False.
This is an experimental prototype that is subject to change and/or deletion.
+enable_grad_cache (bool): Whether to use the cache for grad. Default: True.
+The cache will cost memory for every compiled graph.
+If the input data shape is uncertain, it is advised to disable the cache to save memory.
Raises:
ValueError: If input key is not an attribute in context.
@@ -686,6 +690,7 @@
>>> context.set_context(print_file_path="print.pb")
>>> context.set_context(max_call_depth=80)
>>> context.set_context(env_config_path="./env_config.json")
+>>> context.set_context(enable_grad_cache=True)
"""
ctx = _context()
# set device target first

View File

@@ -88,6 +88,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
set_param<bool>(MS_CTX_LOAD_COMPILE_CACHE, false);
set_param<bool>(MS_CTX_ENABLE_MINDRT, false);
set_param<bool>(MS_CTX_ALREADY_SET_ENABLE_MINDRT, false);
+set_param<bool>(MS_CTX_ENABLE_GRAD_CACHE, true);
backend_policy_ = policy_map_[policy];
}

View File

@@ -90,6 +90,7 @@ enum MsCtxParam : unsigned {
MS_CTX_LOAD_COMPILE_CACHE,
MS_CTX_ENABLE_MINDRT,
MS_CTX_ALREADY_SET_ENABLE_MINDRT,
+MS_CTX_ENABLE_GRAD_CACHE,
MS_CTX_TYPE_BOOL_END,
// parameter of type int