forked from mindspore-Ecosystem/mindspore
!20787 Memory leak fix.
Merge pull request !20787 from zhangzhaoju/master_leak_fix
commit 50d54a7482
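The pattern repeated throughout this diff: grad-related FuncGraph caches that previously grew on every compilation are now gated behind the new MS_CTX_ENABLE_GRAD_CACHE context parameter (Python name enable_grad_cache, default true in ms_context.cc), and what goes into the cache is a BasicClone of the graph rather than the live instance. Below is a minimal standalone sketch of that pattern; the types, key format, and g_-prefixed globals are stand-ins for illustration, not MindSpore code.

#include <iostream>
#include <map>
#include <memory>
#include <string>

struct Graph {
  std::string ir;  // stand-in for a FuncGraph's contents
};
using GraphPtr = std::shared_ptr<Graph>;

// Stand-in for BasicClone: a graph copy with its own independent lifetime.
GraphPtr BasicClone(const GraphPtr &g) { return std::make_shared<Graph>(*g); }

bool g_enable_grad_cache = true;          // stand-in for MS_CTX_ENABLE_GRAD_CACHE
std::map<std::string, GraphPtr> g_cache;  // stand-in for zeros_like_funcgraph_cache

GraphPtr GetSpecializedGraph(const std::string &key) {
  auto it = g_cache.find(key);
  if (it != g_cache.end()) {
    return BasicClone(it->second);  // never hand out the cached instance itself
  }
  GraphPtr built = std::make_shared<Graph>(Graph{"opt(" + key + ")"});  // expensive build + optimize
  // Before this fix the graph was cached unconditionally; with uncertain input
  // shapes every new abstract signature is a new key, so memory grew without bound.
  if (g_enable_grad_cache) {
    g_cache[key] = BasicClone(built);  // the cache keeps a private clone
  }
  return built;  // the caller holds the only other reference
}

int main() {
  g_enable_grad_cache = false;  // i.e. context.set_context(enable_grad_cache=False)
  auto g = GetSpecializedGraph("zeros_like:f32[2,3]");
  std::cout << g->ir << ", cached entries: " << g_cache.size() << "\n";  // cached entries: 0
  return 0;
}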

mindspore/ccsrc/frontend/optimizer/ad/kpynative.cc
@@ -25,7 +25,7 @@
 #include <vector>
 #include <algorithm>
 #include "ir/anf.h"
-#include "pipeline/jit/prim_bprop_optimizer.h"
+#include "frontend/optimizer/ad/prim_bprop_optimizer.h"
 #include "frontend/optimizer/ad/adjoint.h"
 #include "frontend/optimizer/ad/dfunctor.h"
 #include "frontend/optimizer/ad/kpynative.h"
@@ -90,8 +90,11 @@ FuncGraphPtr GetZerosLike(const abstract::AbstractBasePtrList &args_spec) {
   MS_EXCEPTION_IF_NULL(specialized_zeros_like_fg);
   auto opted_zeros_like_fg = ZerosLikePrimOptPass(resource);
   MS_EXCEPTION_IF_NULL(opted_zeros_like_fg);
-  zeros_like_funcgraph_cache[args_spec] = opted_zeros_like_fg;
-  return BasicClone(opted_zeros_like_fg);
+  auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_GRAD_CACHE);
+  if (enable_grad_cache) {
+    zeros_like_funcgraph_cache[args_spec] = BasicClone(opted_zeros_like_fg);
+  }
+  return opted_zeros_like_fg;
 }

 FuncGraphPtr GetHyperAdd(const abstract::AbstractBasePtrList &args_spec) {
@@ -146,8 +149,11 @@ FuncGraphPtr GetOnesLike(const abstract::AbstractBasePtrList &args_spec) {
   pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
   auto specialized_ones_like_fg = pipeline::Renormalize(resource, ones_like_fg, args_spec);
   MS_EXCEPTION_IF_NULL(specialized_ones_like_fg);
-  ones_like_funcgraph_cache[args_spec] = specialized_ones_like_fg;
-  return BasicClone(specialized_ones_like_fg);
+  auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_GRAD_CACHE);
+  if (enable_grad_cache) {
+    ones_like_funcgraph_cache[args_spec] = BasicClone(specialized_ones_like_fg);
+  }
+  return specialized_ones_like_fg;
 }

 AnfNodePtr BuildOnesLikeValue(const FuncGraphPtr &tape, const ValuePtr &out) {
@@ -359,8 +365,8 @@ FuncGraphPtr KPynativeCellImpl::Finish(const AnfNodePtrList &weights, bool grad_
   SetOutput(weights, grad_inputs, grad_weights);
   // Replace Parameter of primal funcgraph with parameter of tape_;
   ReplacePrimalParameter(weights, has_sens_arg);
-
-  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
+  auto save_graphs_flg = MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
+  if (save_graphs_flg) {
     DumpIR("before_final_opt.ir", tape_);
   }
   return tape_;
@@ -645,7 +651,7 @@ bool KPynativeCellImpl::BuildAdjoint(const CNodePtr &cnode, const ValuePtrList &
 FuncGraphPtr OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &cnode, const ValuePtrList &op_args,
                                     const ValuePtr &out) {
   auto optimized_bprop_fg =
-    pipeline::PrimBpropOptimizer::GetPrimBpropOptimizerInst().OptimizeBPropFuncGraph(bprop_fg, cnode, op_args, out);
+    PrimBpropOptimizer::GetPrimBpropOptimizerInst().OptimizeBPropFuncGraph(bprop_fg, cnode, op_args, out);
   return optimized_bprop_fg;
 }

mindspore/ccsrc/frontend/optimizer/ad/prim_bprop_optimizer.cc
@@ -16,11 +16,11 @@

 #include <memory>
 #include "ir/func_graph_cloner.h"
-#include "pipeline/jit/prim_bprop_optimizer.h"
+#include "frontend/optimizer/ad/prim_bprop_optimizer.h"
 #include "pipeline/jit/pass.h"

 namespace mindspore {
-namespace pipeline {
+namespace ad {
 void PrimBpropOptGraphLevel2Info::TryFreeArgsValue(const ValuePtrList &op_args, const ValuePtr &out) {
   // args_value_using_info_ contains out
   if (args_value_using_info_.size() != op_args.size() + 1) {
@@ -231,9 +231,13 @@ FuncGraphPtr PrimBpropOptimizer::GetOptBpropFromCache(const FuncGraphPtr &bprop_
   // do step2 opt
   auto new_abs_list = AddOutToAbsList(out, abs_list);
   level_2_graph_info = PrimBpropOptStep2(level_1_graph, new_abs_list);
-  level_1_graph_info->graph_level_2_cache_[abs_list] = level_2_graph_info;
   level_2_graph_info->TryFreeArgsValue(op_args, out);
-  return BasicClone(level_2_graph_info->opt_func_graph());
+  auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_GRAD_CACHE);
+  if (enable_grad_cache) {
+    level_1_graph_info->graph_level_2_cache_[abs_list] = level_2_graph_info;
+    return BasicClone(level_2_graph_info->opt_func_graph());
+  }
+  return level_2_graph_info->opt_func_graph();
 }

 FuncGraphPtr PrimBpropOptimizer::GenSpecOptBprop(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args,
@@ -256,8 +260,8 @@ FuncGraphPtr PrimBpropOptimizer::GenSpecOptBprop(const FuncGraphPtr &bprop_fg, c
   auto new_abs_list = AddOutToAbsList(out, abs_list);
   auto level_2_graph_info = PrimBpropOptStep2(level_1_graph_info->opt_func_graph_, new_abs_list);
   level_2_graph_info->TryFreeArgsValue(op_args, out);
-
-  if (!hook_flg) {
+  auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_GRAD_CACHE);
+  if (!hook_flg && enable_grad_cache) {
     tuple_list_bprop_cache_[std::pair(prim, abs_list)] = BasicClone(level_2_graph_info->opt_func_graph());
   }
   return level_2_graph_info->opt_func_graph();
@@ -303,7 +307,7 @@ PrimBpropOptGraphLevel2InfoPtr PrimBpropOptimizer::PrimBpropOptStep2(
   return level_2_graph_info;
 }

-FuncGraphPtr PrimBpropOptimizer::BpropGraphFinalOpt(const ResourcePtr &res) const {
+FuncGraphPtr PrimBpropOptimizer::BpropGraphFinalOpt(const pipeline::ResourcePtr &res) const {
   MS_EXCEPTION_IF_NULL(res);
   auto after_opt_bg = BpropGraphFinalOptPass(res);
   return after_opt_bg;
@@ -368,5 +372,5 @@ abstract::AbstractBasePtrList PrimBpropOptimizer::AddOutToAbsList(const ValuePtr
   (void)new_abs_list.emplace_back(out_abs);
   return new_abs_list;
 }
-}  // namespace pipeline
+}  // namespace ad
 }  // namespace mindspore

mindspore/ccsrc/frontend/optimizer/ad/prim_bprop_optimizer.h
@@ -14,8 +14,8 @@
  * limitations under the License.
  */

-#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PRIM_BPROP_OPTIMIZER_H
-#define MINDSPORE_CCSRC_PIPELINE_JIT_PRIM_BPROP_OPTIMIZER_H
+#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_PRIM_BPROP_OPTIMIZER_H
+#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_PRIM_BPROP_OPTIMIZER_H

 #include <vector>
 #include <utility>
@@ -27,7 +27,7 @@
 #include "pipeline/jit/resource.h"

 namespace mindspore {
-namespace pipeline {
+namespace ad {
 struct PrimBpropOptGraphInfo;

 class PrimBpropOptGraphLevel2Info;
@@ -144,7 +144,7 @@ class PrimBpropOptimizer {
                            const ValuePtr &out);

   // do inline opt for final bprop graph
-  FuncGraphPtr BpropGraphFinalOpt(const ResourcePtr &res) const;
+  FuncGraphPtr BpropGraphFinalOpt(const pipeline::ResourcePtr &res) const;

  private:
   PrimBpropOptimizer() = default;
@@ -179,8 +179,7 @@ class PrimBpropOptimizer {
   PrimBpropCache prim_bprop_cache_;
   PrimTupleListCache tuple_list_bprop_cache_;
 };
-
-}  // namespace pipeline
+}  // namespace ad
 }  // namespace mindspore

-#endif  // MINDSPORE_CCSRC_PIPELINE_JIT_PRIM_BPROP_OPTIMIZER_H
+#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_PRIM_BPROP_OPTIMIZER_H

mindspore/ccsrc/pipeline/jit/CMakeLists.txt
@@ -8,7 +8,6 @@ file(GLOB_RECURSE _PIPELINE_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
     "pipeline_split.cc"
     "parse/*.cc"
     "static_analysis/*.cc"
-    "prim_bprop_optimizer.cc"
     )

mindspore/ccsrc/pipeline/jit/pipeline.cc
@@ -51,7 +51,7 @@
 #include "utils/shape_utils.h"
 #include "utils/info.h"
 #include "load_mindir/load_model.h"
-#include "pipeline/jit/prim_bprop_optimizer.h"
+#include "frontend/optimizer/ad/prim_bprop_optimizer.h"
 #include "runtime/hardware/device_context_manager.h"
 #include "utils/crypto.h"
@@ -1320,7 +1320,7 @@ void ClearResAtexit() {
   device::DeviceContextManager::GetInstance().ClearDeviceContexts();
   ad::g_k_prims.clear();
   ad::ClearKPynativeCellStaticRes();
-  PrimBpropOptimizer::GetPrimBpropOptimizerInst().Clear();
+  ad::PrimBpropOptimizer::GetPrimBpropOptimizerInst().Clear();

   abstract::ClearPrimEvaluatorMap();
   pipeline::GetMethodMap().clear();

mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
@@ -59,7 +59,7 @@
 #include "pipeline/jit/resource.h"
 #include "pipeline/jit/pass.h"
 #include "frontend/parallel/context.h"
-#include "pipeline/jit/prim_bprop_optimizer.h"
+#include "frontend/optimizer/ad/prim_bprop_optimizer.h"

 #ifdef ENABLE_GE
 #include "pipeline/pynative/pynative_execute_ge.h"
@@ -2447,7 +2447,7 @@ FuncGraphPtr GradExecutor::GetBpropGraph(const prim::GradOperationPtr &grad, con
   auto manager = resource->manager();
   MS_EXCEPTION_IF_NULL(manager);
   manager->AddFuncGraph(bprop_graph);
-  auto optimized_bg = pipeline::PrimBpropOptimizer::GetPrimBpropOptimizerInst().BpropGraphFinalOpt(resource);
+  auto optimized_bg = ad::PrimBpropOptimizer::GetPrimBpropOptimizerInst().BpropGraphFinalOpt(resource);

   if (cell_stack_.empty()) {
     need_renormalize_ = false;

mindspore/ccsrc/pybind_api/utils/ms_context_py.cc
@@ -100,7 +100,8 @@ REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) {
                          .value("graph_kernel_flags", MsCtxParam::MS_CTX_GRAPH_KERNEL_FLAGS)
                          .value("grad_for_scalar", MsCtxParam::MS_CTX_GRAD_FOR_SCALAR)
                          .value("save_compile_cache", MsCtxParam::MS_CTX_SAVE_COMPILE_CACHE)
-                         .value("load_compile_cache", MsCtxParam::MS_CTX_LOAD_COMPILE_CACHE);
+                         .value("load_compile_cache", MsCtxParam::MS_CTX_LOAD_COMPILE_CACHE)
+                         .value("enable_grad_cache", MsCtxParam::MS_CTX_ENABLE_GRAD_CACHE);
                        (void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(*m, "MSContext")
                          .def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.")
                          .def("get_param", &mindspore::MsCtxGetParameter, "Get value of specified parameter.")
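The .value chain above is a pybind11 enum binding: each call pairs a Python-visible name with a MsCtxParam constant, which is how the "enable_grad_cache" key typed in context.py reaches the C++ side. A minimal sketch of the same mechanism follows; the module name ctx_demo and the two-entry enum are illustrative stand-ins, while MindSpore itself wraps the registration in its REGISTER_PYBIND_DEFINE macro.

// Hypothetical minimal pybind11 module demonstrating the enum-binding mechanism.
#include <pybind11/pybind11.h>
namespace py = pybind11;

enum MsCtxParam : unsigned {
  MS_CTX_LOAD_COMPILE_CACHE,
  MS_CTX_ENABLE_GRAD_CACHE,
};

PYBIND11_MODULE(ctx_demo, m) {
  // Python sees ctx_demo.ms_ctx_param.enable_grad_cache as an enum member,
  // which get_param/set_param bindings can then accept as an argument.
  py::enum_<MsCtxParam>(m, "ms_ctx_param")
      .value("load_compile_cache", MS_CTX_LOAD_COMPILE_CACHE)
      .value("enable_grad_cache", MS_CTX_ENABLE_GRAD_CACHE);
}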

mindspore/context.py
@@ -514,7 +514,7 @@ def _check_target_specific_cfgs(device, arg_key):
                  enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
                  enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str,
                  enable_sparse=bool, max_call_depth=int, env_config_path=str, graph_kernel_flags=str,
-                 save_compile_cache=bool, load_compile_cache=bool, grad_for_scalar=bool)
+                 save_compile_cache=bool, load_compile_cache=bool, grad_for_scalar=bool, enable_grad_cache=bool)
 def set_context(**kwargs):
     """
     Set context for running environment.
@@ -552,6 +552,7 @@ def set_context(**kwargs):
     grad_for_scalar
     save_compile_cache
     load_compile_cache
+    enable_grad_cache
     =========================== =========================== =================

     Args:
@@ -663,6 +664,9 @@ def set_context(**kwargs):
             you should make sure the network has not been changed since the last execution. By now, we have
             not support automatically checking the changes yet. Default: False.
             This is an experimental prototype that is subject to change and/or deletion.
+        enable_grad_cache (bool): Whether to cache compiled grad graphs. Default: True.
+            The cache costs memory for every compiled graph.
+            If the input data shape is uncertain, it is advised to disable the cache to save memory.

     Raises:
         ValueError: If input key is not an attribute in context.
@@ -686,6 +690,7 @@ def set_context(**kwargs):
         >>> context.set_context(print_file_path="print.pb")
         >>> context.set_context(max_call_depth=80)
         >>> context.set_context(env_config_path="./env_config.json")
+        >>> context.set_context(enable_grad_cache=True)
     """
     ctx = _context()
     # set device target first

mindspore/core/utils/ms_context.cc
@@ -88,6 +88,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
   set_param<bool>(MS_CTX_LOAD_COMPILE_CACHE, false);
   set_param<bool>(MS_CTX_ENABLE_MINDRT, false);
   set_param<bool>(MS_CTX_ALREADY_SET_ENABLE_MINDRT, false);
+  set_param<bool>(MS_CTX_ENABLE_GRAD_CACHE, true);

   backend_policy_ = policy_map_[policy];
 }

mindspore/core/utils/ms_context.h
@@ -90,6 +90,7 @@ enum MsCtxParam : unsigned {
   MS_CTX_LOAD_COMPILE_CACHE,
   MS_CTX_ENABLE_MINDRT,
   MS_CTX_ALREADY_SET_ENABLE_MINDRT,
+  MS_CTX_ENABLE_GRAD_CACHE,
   MS_CTX_TYPE_BOOL_END,

   // parameter of type int
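Taken together, a new bool context parameter touches four places in this diff: the MsCtxParam enum above, the default value in the MsContext constructor, the pybind .value name, and the validator plus docstring in context.py. The sketch below condenses how the enum-indexed storage plausibly fits together; only set_param/get_param and the parameter names appear in the real diff, everything else (the array layout, the simplified class) is an assumption for illustration.

// Condensed stand-in for MsContext's bool-parameter storage.
#include <array>
#include <iostream>

enum MsCtxParam : unsigned {
  MS_CTX_ENABLE_MINDRT,
  MS_CTX_ENABLE_GRAD_CACHE,  // the parameter this PR adds
  MS_CTX_TYPE_BOOL_END,      // sentinel: keeps the array size in sync
};

class MsContext {
 public:
  MsContext() {
    set_param<bool>(MS_CTX_ENABLE_MINDRT, false);
    set_param<bool>(MS_CTX_ENABLE_GRAD_CACHE, true);  // default: cache on
  }
  template <typename T>
  void set_param(MsCtxParam p, const T &v) {
    bool_params_.at(p) = v;
  }
  template <typename T>
  T get_param(MsCtxParam p) const {
    return bool_params_.at(p);
  }

 private:
  std::array<bool, MS_CTX_TYPE_BOOL_END> bool_params_{};
};

int main() {
  MsContext ctx;
  // Use sites in the diff read the flag like this before touching a cache.
  if (ctx.get_param<bool>(MS_CTX_ENABLE_GRAD_CACHE)) {
    std::cout << "grad cache enabled\n";
  }
  return 0;
}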