!22000 Fix memory leak in PyNative

Merge pull request !22000 from zjun/fix_memory_leak
i-robot 2021-08-21 09:13:43 +00:00 committed by Gitee
commit 80b6e4debc
16 changed files with 116 additions and 104 deletions
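
What the change does: every backend session's BuildOpImpl now returns the built KernelGraphPtr directly and inserts it into run_op_graphs_ only when the new MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE context flag is on, and RunOpImpl consumes the returned graph instead of re-fetching it from the map. Graphs built while caching is disabled (dynamic shapes, too many top cells) are therefore released after the op runs instead of accumulating. A minimal, self-contained sketch of that pattern; GraphCache, KernelGraph, and the inline build are illustrative stand-ins, not MindSpore APIs:

#include <memory>
#include <string>
#include <unordered_map>

struct KernelGraph {};
using KernelGraphPtr = std::shared_ptr<KernelGraph>;

class GraphCache {
 public:
  KernelGraphPtr GetOrBuild(const std::string &graph_info, bool enable_cache) {
    auto it = graphs_.find(graph_info);
    if (it != graphs_.end()) {
      return it->second;  // Cache hit: reuse the graph built earlier.
    }
    auto graph = std::make_shared<KernelGraph>();  // Stand-in for the real build steps.
    if (enable_cache) {
      graphs_[graph_info] = graph;  // Retain the graph only when caching is enabled.
    }
    return graph;  // An uncached graph is freed once the caller drops it.
  }

 private:
  std::unordered_map<std::string, KernelGraphPtr> graphs_;
};
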

View File

@@ -628,15 +628,12 @@ void AscendSession::RunOpHardwareOptimize(const std::shared_ptr<session::KernelG
MS_LOG(INFO) << "HardwareOptimize Finish";
}
bool AscendSession::GraphCacheExist(const GraphInfo &graph_info) const {
return run_op_graphs_.find(graph_info) != run_op_graphs_.end();
}
void AscendSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int64_t> &tensors_mask) {
if (GraphCacheExist(graph_info)) {
return;
KernelGraphPtr AscendSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int64_t> &tensors_mask) {
auto it = run_op_graphs_.find(graph_info);
if (it != run_op_graphs_.end()) {
return it->second;
}
const auto &graph = PreBuildOp(op_run_info, input_tensors, tensors_mask);
@@ -646,7 +643,11 @@ void AscendSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &g
// build kernel
RunOpAdjustKernel(graph);
BuildKernel(graph);
run_op_graphs_[graph_info] = graph;
auto enable_op_graph_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
if (enable_op_graph_cache) {
run_op_graphs_[graph_info] = graph;
}
return graph;
}
void AscendSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
@@ -654,7 +655,7 @@ void AscendSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_inf
const std::vector<int64_t> &tensors_mask) {
MS_EXCEPTION_IF_NULL(input_tensors);
MS_EXCEPTION_IF_NULL(op_run_info);
BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask);
const auto &graph = BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask);
EraseValueNodeTensor(tensors_mask, input_tensors);
// wait for allreduce
@@ -663,9 +664,6 @@ void AscendSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_inf
tensor->WaitDevice();
}
}
// Run op
auto graph = run_op_graphs_[graph_info];
MS_EXCEPTION_IF_NULL(graph);
// malloc mem
RunOpRemoveNopNode(graph);
RunOpMemoryAlloc(*input_tensors, graph.get());
@@ -792,7 +790,10 @@ void AscendSession::BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfN
// Record single op graphs in run_op_graphs_ so that these graphs can be reused in BuildOpImpl
for (const auto &graph_item : single_op_graphs) {
RunOpMemoryClear(graph_item.first.get());
run_op_graphs_[graph_item.second] = graph_item.first;
auto enable_op_graph_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
if (enable_op_graph_cache) {
run_op_graphs_[graph_item.second] = graph_item.first;
}
MS_LOG(DEBUG) << "Pre build op finished, graph info: " << graph_item.second;
}
built_graph_id_.insert(graph_id);
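
A detail worth calling out in the RunOpImpl hunk above: the deleted auto graph = run_op_graphs_[graph_info]; used operator[], which default-constructs the mapped value on a miss. Once caching can be switched off, that re-lookup would both return a null pointer and grow the map as a side effect; returning the graph from BuildOpImpl avoids both. A small stand-alone demonstration (not MindSpore code):

#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<std::string, std::shared_ptr<int>> cache;
  auto graph = cache["missing"];  // operator[] inserts a default (null) entry
  assert(graph == nullptr);       // the lookup "succeeded" but yielded nothing
  assert(cache.size() == 1);      // and the map grew as a side effect
  return 0;
}
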

View File

@@ -57,9 +57,9 @@ class AscendSession : public SessionBasic {
VectorRef *const outputs) override;
void ExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph) override;
void BuildGraphImpl(GraphId) override;
void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int64_t> &tensors_mask) override;
KernelGraphPtr BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int64_t> &tensors_mask) override;
void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector<tensor::TensorPtr> *input_tensors,
VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override;
void BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfNodePtr, size_t> &parameter_index,
@@ -104,8 +104,6 @@ class AscendSession : public SessionBasic {
const std::vector<GraphId> &GetGraphOrder(GraphId final_graph_id) const;
// get graph order type vector by graph id
const std::vector<GraphType> &GetGraphOrderType(GraphId final_graph_id) const;
// check if graph cache exist
bool GraphCacheExist(const GraphInfo &graph_info) const;
// sync initial tensors' data to device
void SyncInitialTenosrToDevice();
void SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph);

View File

@@ -212,21 +212,27 @@ void CPUSession::ExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph)
}
}
void CPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int64_t> &tensors_mask) {
KernelGraphPtr CPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int64_t> &tensors_mask) {
// Check if the graph cache exists.
if (run_op_graphs_.find(graph_info) != run_op_graphs_.end()) {
return;
auto it = run_op_graphs_.find(graph_info);
if (it != run_op_graphs_.end()) {
return it->second;
}
// Prepare the graph
auto kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask);
const auto &kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask);
MS_EXCEPTION_IF_NULL(kernel_graph);
SetKernelInfo(kernel_graph.get());
Optimize(kernel_graph);
BuildKernel(kernel_graph.get());
ProcessCast(kernel_graph);
run_op_graphs_[graph_info] = kernel_graph;
auto enable_op_graph_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
if (enable_op_graph_cache) {
run_op_graphs_[graph_info] = kernel_graph;
}
return kernel_graph;
}
void CPUSession::SetOutputFlags(const VectorRef &base_ref) {
@@ -260,12 +266,8 @@ void CPUSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
const std::vector<int64_t> &tensors_mask) {
MS_EXCEPTION_IF_NULL(input_tensors);
MS_EXCEPTION_IF_NULL(op_run_info);
BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask);
const auto &kernel_graph = BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask);
EraseValueNodeTensor(tensors_mask, input_tensors);
auto kernel_graph = run_op_graphs_[graph_info];
MS_EXCEPTION_IF_NULL(kernel_graph);
// Remove reorder after the PS feature finishes adapting push/pull in auto_monad.
auto execution_order = kernel_graph->execution_order();
Reorder(&execution_order);

View File

@@ -43,9 +43,9 @@ class CPUSession : public SessionBasic {
void ExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph) override;
ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) override;
void Optimize(const std::shared_ptr<KernelGraph> &kernel_graph);
void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int64_t> &tensors_mask) override;
KernelGraphPtr BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int64_t> &tensors_mask) override;
void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector<tensor::TensorPtr> *input_tensors,
VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override;
void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,

View File

@@ -601,16 +601,17 @@ void GPUSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const
}
}
void GPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int64_t> &tensors_mask) {
KernelGraphPtr GPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int64_t> &tensors_mask) {
// Check if the graph cache exists.
if (run_op_graphs_.find(graph_info) != run_op_graphs_.end() &&
kOpCacheBlackList.find(op_run_info.op_name) == kOpCacheBlackList.end()) {
return;
auto it = run_op_graphs_.find(graph_info);
if (it != run_op_graphs_.end() && kOpCacheBlackList.find(op_run_info.op_name) == kOpCacheBlackList.end()) {
return it->second;
}
// Prepare the graph
auto kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask);
const auto &kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask);
MS_EXCEPTION_IF_NULL(kernel_graph);
RunOpOptimize(kernel_graph);
SelectKernel(kernel_graph);
@@ -618,7 +619,11 @@ void GPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &grap
StartKernelRT();
RunOpHideNopNode(kernel_graph);
BuildKernel(kernel_graph);
run_op_graphs_[graph_info] = kernel_graph;
auto enable_op_graph_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
if (enable_op_graph_cache) {
run_op_graphs_[graph_info] = kernel_graph;
}
return kernel_graph;
}
void GPUSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
@@ -626,7 +631,7 @@ void GPUSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
const std::vector<int64_t> &tensors_mask) {
MS_EXCEPTION_IF_NULL(input_tensors);
MS_EXCEPTION_IF_NULL(op_run_info);
BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask);
const auto &kernel_graph = BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask);
EraseValueNodeTensor(tensors_mask, input_tensors);
// wait for allreduce
for (auto &tensor : *input_tensors) {
@@ -636,7 +641,6 @@ void GPUSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
}
}
// run op
auto kernel_graph = run_op_graphs_[graph_info];
MS_EXCEPTION_IF_NULL(kernel_graph);
RunOpRemoveNopNode(kernel_graph);
RunOpAllocateMemory(*input_tensors, kernel_graph.get());
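
Note the one GPU-specific twist retained above: even on a cache hit, ops listed in kOpCacheBlackList are rebuilt rather than served from run_op_graphs_, so cached-graph reuse never applies to them.
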

View File

@@ -45,9 +45,9 @@ class GPUSession : public SessionBasic {
void PostExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph, const std::vector<tensor::TensorPtr> &inputs,
VectorRef *const outputs) override;
void ExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph) override;
void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int64_t> &tensors_mask) override;
KernelGraphPtr BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int64_t> &tensors_mask) override;
void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector<tensor::TensorPtr> *input_tensors,
VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override;
std::shared_ptr<device::Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) override;

View File

@@ -217,9 +217,11 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
const std::vector<tensor::TensorPtr> &inputs, VectorRef *const outputs) {}
virtual void ExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph) {}
void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
virtual void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int64_t> &tensors_mask) {}
virtual KernelGraphPtr BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int64_t> &tensors_mask) {
return nullptr;
}
virtual void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
const std::vector<int64_t> &tensors_mask) {}
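
Since BuildOpImpl now returns a KernelGraphPtr, the SessionBasic default implementation returns nullptr rather than silently doing nothing, giving callers on backends without an override an explicit value to check.
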

View File

@@ -90,7 +90,7 @@ FuncGraphPtr GetZerosLike(const abstract::AbstractBasePtrList &args_spec) {
MS_EXCEPTION_IF_NULL(specialized_zeros_like_fg);
auto opted_zeros_like_fg = ZerosLikePrimOptPass(resource);
MS_EXCEPTION_IF_NULL(opted_zeros_like_fg);
auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_GRAD_CACHE);
auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
if (enable_grad_cache) {
zeros_like_funcgraph_cache[args_spec] = BasicClone(opted_zeros_like_fg);
}
@@ -149,7 +149,7 @@ FuncGraphPtr GetOnesLike(const abstract::AbstractBasePtrList &args_spec) {
pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
auto specialized_ones_like_fg = pipeline::Renormalize(resource, ones_like_fg, args_spec);
MS_EXCEPTION_IF_NULL(specialized_ones_like_fg);
auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_GRAD_CACHE);
auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
if (enable_grad_cache) {
ones_like_funcgraph_cache[args_spec] = BasicClone(specialized_ones_like_fg);
}

View File

@@ -232,7 +232,7 @@ FuncGraphPtr PrimBpropOptimizer::GetOptBpropFromCache(const FuncGraphPtr &bprop_
auto new_abs_list = AddOutToAbsList(out, abs_list);
level_2_graph_info = PrimBpropOptStep2(level_1_graph, new_abs_list);
level_2_graph_info->TryFreeArgsValue(op_args, out);
auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_GRAD_CACHE);
auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
if (enable_grad_cache) {
level_1_graph_info->graph_level_2_cache_[abs_list] = level_2_graph_info;
return BasicClone(level_2_graph_info->opt_func_graph());
@@ -260,7 +260,7 @@ FuncGraphPtr PrimBpropOptimizer::GenSpecOptBprop(const FuncGraphPtr &bprop_fg, c
auto new_abs_list = AddOutToAbsList(out, abs_list);
auto level_2_graph_info = PrimBpropOptStep2(level_1_graph_info->opt_func_graph_, new_abs_list);
level_2_graph_info->TryFreeArgsValue(op_args, out);
auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_GRAD_CACHE);
auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
if (!hook_flg && enable_grad_cache) {
tuple_list_bprop_cache_[std::pair(prim, abs_list)] = BasicClone(level_2_graph_info->opt_func_graph());
}

View File

@@ -702,12 +702,6 @@ py::object GetDstType(const TypeId &type_id) {
MS_EXCEPTION_IF_NULL(value);
return py::cast(value);
}
void EnableGraphCache(bool flag) {
const auto inst = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(inst);
inst->set_param<bool>(MS_CTX_ENABLE_GRAD_CACHE, flag);
}
} // namespace
py::object RealRunOp(const py::args &args) {
@@ -992,7 +986,7 @@ void ForwardExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
if (shape->IsDynamic()) {
op_exec_info->is_dynamic_shape = true;
// Dynamic shape operator in the current top cell, disable backend cache
EnableGraphCache(false);
grad()->EnableOpGraphCache(false);
}
}
@@ -1016,7 +1010,7 @@ void ForwardExecutor::GetOpOutput(const OpExecInfoPtr &op_exec_info,
}
// Add output abstract info into the cache; the const value needs to be inferred every step
if (!prim_cache_hit && !op_exec_info->is_dynamic_shape) {
if (grad()->enable_op_cache() && !prim_cache_hit && !op_exec_info->is_dynamic_shape) {
AbsCacheKey key{prim->name(), prim->Hash(), prim->attrs()};
auto &out = prim_abs_list_[key];
out[args_spec_list].abs = op_exec_info->abstract;
@@ -1338,6 +1332,13 @@ TopCellInfoPtr GradExecutor::GetTopCell(const std::string &cell_id) const {
return nullptr;
}
void GradExecutor::EnableOpGraphCache(bool is_enable) {
enable_op_cache_ = is_enable;
const auto inst = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(inst);
inst->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE, is_enable);
}
void GradExecutor::RecordGradOpInfo(const OpExecInfoPtr &op_exec_info, const py::object &ret) {
if (!grad_flag_) {
MS_LOG(DEBUG) << "Grad flag is set to false, no need to record op info";
@@ -1515,7 +1516,7 @@ void GradExecutor::UpdateForwardTensorInfoInBpropGraph(const OpExecInfoPtr &op_e
}
// First run top cell
if (already_run_top_cell_.find(top_cell_->cell_id()) == already_run_top_cell_.end()) {
if (already_run_top_cell_.find(top_cell_->already_run_cell_id()) == already_run_top_cell_.end()) {
MS_LOG(DEBUG) << "Top cell " << top_cell_->cell_id() << " run firstly";
if (!need_construct_graph()) {
MS_LOG(EXCEPTION) << "The cell stack is empty when running a new top cell " << top_cell_->cell_id();
@@ -1523,7 +1524,7 @@ void GradExecutor::UpdateForwardTensorInfoInBpropGraph(const OpExecInfoPtr &op_e
return;
}
// Non-first run
const auto &pre_top_cell = already_run_top_cell_.at(top_cell_->cell_id());
const auto &pre_top_cell = already_run_top_cell_.at(top_cell_->already_run_cell_id());
MS_EXCEPTION_IF_NULL(pre_top_cell);
if (pre_top_cell->op_info_with_tensor_id().find(op_info) == pre_top_cell->op_info_with_tensor_id().end()) {
MS_LOG(DEBUG) << "Can not find op info " << op_info << " in op info with tensor id map. Top cell "
@@ -1895,13 +1896,12 @@ void GradExecutor::ClearCellRes(const std::string &cell_id) {
}
// Clear on cell destruction
for (auto it = top_cell_list_.begin(); it != top_cell_list_.end();) {
auto top_cell_id = (*it)->cell_id();
const auto &top_cell_id = (*it)->cell_id();
const auto &already_top_cell_id = (*it)->already_run_cell_id();
if (IsCellObjIdEq(cell_id, top_cell_id)) {
(*it)->Clear();
it = top_cell_list_.erase(it);
if (already_run_top_cell_.find(top_cell_id) != already_run_top_cell_.end()) {
(void)already_run_top_cell_.erase(top_cell_id);
}
(void)already_run_top_cell_.erase(already_top_cell_id);
MS_LOG(DEBUG) << "Clear top cell resource. Top cell id " << top_cell_id;
continue;
}
@@ -1952,7 +1952,7 @@ void GradExecutor::HandleInputArgsForTopCell(const py::args &args, bool is_bprop
}
// Convert input args to parameters for top cell graph in construct.
std::vector<ValuePtr> input_param_values;
py::list only_tensors = FilterTensorArgs(args);
const auto &only_tensors = FilterTensorArgs(args);
auto df_builder = GetDfbuilder(top_cell_->cell_id());
MS_EXCEPTION_IF_NULL(df_builder);
for (size_t i = 0; i < only_tensors.size(); ++i) {
@@ -2017,11 +2017,18 @@ void GradExecutor::InitResourceAndDfBuilder(const std::string &cell_id, const py
void GradExecutor::NewGraphInner(py::object *ret, const py::object &cell, const py::args &args) {
MS_EXCEPTION_IF_NULL(ret);
auto cell_id = GetCellId(cell, args);
const auto &cell_id = GetCellId(cell, args);
MS_LOG(DEBUG) << "NewGraphInner start " << args.size() << " " << cell_id;
if (top_cell_ != nullptr && cell_stack_.empty()) {
// An already-run top cell id needs to distinguish the high-order case: append "0" for high order, otherwise "1"
std::string already_run_cell_id;
if (IsNestedGrad()) {
already_run_cell_id = cell_id + "0";
} else {
already_run_cell_id = cell_id + "1";
}
// Whether it is top and has been run
auto top_it = already_run_top_cell_.find(cell_id);
auto top_it = already_run_top_cell_.find(already_run_cell_id);
if (top_it != already_run_top_cell_.end()) {
// Top cell forward run.
const auto &pre_top_cell = top_it->second;
@@ -2032,8 +2039,8 @@ void GradExecutor::NewGraphInner(py::object *ret, const py::object &cell, const
set_top_cell(pre_top_cell);
return;
}
} else if ((top_cell()->IsSubCell(cell_id) && !IsCellObjIdEq(cell_id, check_graph_cell_id_)) ||
GetHighOrderStackSize() >= 1) {
} else if ((top_cell()->IsSubCell(cell_id) || GetHighOrderStackSize() >= 1) &&
!IsCellObjIdEq(cell_id, check_graph_cell_id_)) {
// Sub cell (or possibly a temporary cell, but it must be non-top) forward run in the cache process.
MS_LOG(DEBUG) << "Sub cell no need to run NewGraphInner again";
return;
@@ -2069,13 +2076,11 @@ void GradExecutor::MakeNewTopGraph(const string &cell_id, const py::args &args,
// The number of top cells exceeds MAX_TOP_CELL_COUNTS, delete the last one to keep the maximum length of the list,
// disable backend cache
if (top_cell_list_.size() >= MAX_TOP_CELL_COUNTS) {
EnableGraphCache(false);
EnableOpGraphCache(false);
const auto last_top_cell = top_cell_list_.back();
top_cell_list_.pop_back();
last_top_cell->Clear();
if (already_run_top_cell_.find(last_top_cell->cell_id()) != already_run_top_cell_.end()) {
(void)already_run_top_cell_.erase(last_top_cell->cell_id());
}
(void)already_run_top_cell_.erase(last_top_cell->already_run_cell_id());
}
// Create top cell
curr_g_ = std::make_shared<FuncGraph>();
@@ -2535,16 +2540,16 @@ py::object PynativeExecutor::CheckAlreadyRun(const py::object &cell, const py::a
void GradExecutor::CheckNeedCompileGraph() {
auto new_top_cell = top_cell();
std::string top_cell_id = new_top_cell->cell_id();
// update top cell by current cell op info
if (already_run_top_cell_.find(top_cell_id) == already_run_top_cell_.end()) {
MS_LOG(DEBUG) << "Top cell " << top_cell_id << " has never been ran, need compile graph";
already_run_top_cell_[top_cell_id] = new_top_cell;
const auto &already_top_cell_id = new_top_cell->already_run_cell_id();
// Update top cell by current cell op info
if (already_run_top_cell_.find(already_top_cell_id) == already_run_top_cell_.end()) {
MS_LOG(DEBUG) << "Top cell " << new_top_cell->cell_id() << " has never been ran, need compile graph";
already_run_top_cell_[already_top_cell_id] = new_top_cell;
return;
}
MS_LOG(DEBUG) << "Top cell " << top_cell_id << " has been ran";
auto pre_top_cell = already_run_top_cell_.at(top_cell_id);
MS_LOG(DEBUG) << "Top cell " << new_top_cell->cell_id() << " has been ran";
auto pre_top_cell = already_run_top_cell_.at(already_top_cell_id);
auto pre_all_op_info = pre_top_cell->all_op_info();
auto new_all_op_info = new_top_cell->all_op_info();
MS_LOG(DEBUG) << "Pre all op info : " << pre_all_op_info;
@@ -2553,14 +2558,14 @@ void GradExecutor::CheckNeedCompileGraph() {
MS_LOG(DEBUG) << "The op info has been changed, need to compile graph again";
// The number of top cell switches exceeds MAX_TOP_CELL_COUNTS under control flow, disable backend cache
if (top_cell_switch_counts_ >= MAX_TOP_CELL_COUNTS) {
EnableGraphCache(false);
EnableOpGraphCache(false);
} else {
// Increase top cell switches counts
++top_cell_switch_counts_;
}
EraseTopCellFromTopCellList(pre_top_cell);
pre_top_cell->Clear();
already_run_top_cell_[top_cell_id] = new_top_cell;
already_run_top_cell_[already_top_cell_id] = new_top_cell;
} else {
MS_LOG(DEBUG) << "The op info has not been changed, no need to compile graph again";
pre_top_cell->set_input_args_id(new_top_cell->input_args_id());
@@ -2813,6 +2818,7 @@ void GradExecutor::ClearRes() {
grad_flag_ = false;
need_renormalize_ = false;
grad_is_running_ = false;
enable_op_cache_ = true;
top_cell_ = nullptr;
curr_g_ = nullptr;
bprop_cell_list_.clear();
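
The new GradExecutor::EnableOpGraphCache above flips the switch in two places at once: the executor's own enable_op_cache_ member (read in ForwardExecutor::GetOpOutput before populating prim_abs_list_) and the global context flag (read by the backend sessions before populating run_op_graphs_). A reduced sketch of that coupling, with MiniContext and GradExecutorSketch as illustrative stand-ins:

#include <cassert>

class MiniContext {
 public:
  static MiniContext &GetInstance() {
    static MiniContext inst;
    return inst;
  }
  void set_enable_op_graph_cache(bool v) { enable_op_graph_cache_ = v; }
  bool enable_op_graph_cache() const { return enable_op_graph_cache_; }

 private:
  bool enable_op_graph_cache_{true};
};

class GradExecutorSketch {
 public:
  void EnableOpGraphCache(bool is_enable) {
    enable_op_cache_ = is_enable;                                     // frontend abstract cache
    MiniContext::GetInstance().set_enable_op_graph_cache(is_enable);  // backend graph cache
  }
  bool enable_op_cache() const { return enable_op_cache_; }

 private:
  bool enable_op_cache_{true};
};

int main() {
  GradExecutorSketch grad;
  grad.EnableOpGraphCache(false);  // e.g. dynamic shape seen, or too many top cells
  assert(!grad.enable_op_cache());
  assert(!MiniContext::GetInstance().enable_op_graph_cache());
  return 0;
}
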

View File

@@ -42,7 +42,6 @@
namespace mindspore::pynative {
namespace py = pybind11;
using CellId = std::string;
using MsFunctionGradCache = std::unordered_map<std::string, std::pair<FuncGraphPtr, FuncGraphPtr>>;
using OpInfoWithTensorId = std::unordered_map<std::string, std::vector<std::string>>;
using TensorIdWithTensorObject = std::unordered_map<std::string, std::vector<tensor::TensorPtr>>;
@@ -68,7 +67,8 @@ class TopCellInfo {
grad_order_(grad_order),
resource_(std::move(r)),
df_builder_(std::move(df)),
cell_id_(std::move(cellid)) {}
cell_id_(std::move(cellid)),
already_run_cell_id_(cell_id_ + std::to_string(is_topest_)) {}
bool is_init_kpynative() const { return is_init_kpynative_; }
void set_init_kpynative(bool init) { is_init_kpynative_ = init; }
@@ -90,9 +90,10 @@ class TopCellInfo {
size_t op_num() const { return op_num_; }
void set_op_num(size_t op_num) { op_num_ = op_num; }
std::string &cell_id() { return cell_id_; }
std::string &already_run_cell_id() { return already_run_cell_id_; }
std::string &input_args_id() { return input_args_id_; }
std::string &all_op_info() { return all_op_info_; }
void set_input_args_id(const std::string &input_args_id) { input_args_id_ = std::move(input_args_id); }
void set_input_args_id(const std::string &input_args_id) { input_args_id_ = input_args_id; }
std::unordered_set<std::string> &sub_cell_list() { return sub_cell_list_; }
bool IsSubCell(const std::string &cell_id) const;
OrderedMap<FuncGraphPtr, GraphInfoPtr> &graph_info_map() { return graph_info_map_; }
@@ -124,6 +125,7 @@ class TopCellInfo {
FuncGraphPtr df_builder_{nullptr};
ad::KPynativeCellPtr k_pynative_cell_ptr_{nullptr};
std::string cell_id_;
std::string already_run_cell_id_;
std::string input_args_id_;
std::string all_op_info_;
OrderedMap<FuncGraphPtr, GraphInfoPtr> graph_info_map_;
@@ -173,7 +175,9 @@ class GradExecutor {
TopCellInfoPtr top_cell() const;
void CheckNeedCompileGraph();
TopCellInfoPtr GetTopCell(const string &cell_id) const;
void EnableOpGraphCache(bool is_enable);
bool need_renormalize() const { return need_renormalize_; }
bool enable_op_cache() const { return enable_op_cache_; }
void set_top_cell(TopCellInfoPtr top_cell) { top_cell_ = std::move(top_cell); }
bool grad_flag() const { return grad_flag_; }
void set_grad_flag(bool flag) { grad_flag_ = flag; }
@@ -242,15 +246,15 @@ class GradExecutor {
const std::vector<int64_t> &index_sequence, bool is_param = false);
void SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node,
bool is_param = false);
void SetParamNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const ParameterPtr &param) {
void SetParamNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const ParameterPtr &param) const {
top_cell()->graph_info_map()[g]->params[id] = param;
}
void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node,
int64_t index = -1) {
int64_t index = -1) const {
top_cell()->graph_info_map()[g]->node_map[id] = std::make_pair(node, std::vector<int64_t>{index});
}
void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node,
const std::vector<int64_t> &index) {
const std::vector<int64_t> &index) const {
top_cell()->graph_info_map()[g]->node_map[id] = std::make_pair(node, index);
}
void CreateMakeTupleNodeForMultiOut(const FuncGraphPtr &curr_g, const py::object &out, const std::string &out_id);
@@ -259,6 +263,7 @@ class GradExecutor {
bool grad_flag_{false};
bool need_renormalize_{false};
bool grad_is_running_{false};
bool enable_op_cache_{true};
int custom_bprop_cell_count_{0};
size_t grad_order_{0};
size_t top_cell_switch_counts_{0};
@@ -281,7 +286,7 @@ class GradExecutor {
// Use vector for keep order
std::vector<TopCellInfoPtr> top_cell_list_;
// Record all top cell which has been ran
std::map<CellId, TopCellInfoPtr> already_run_top_cell_;
std::unordered_map<std::string, TopCellInfoPtr> already_run_top_cell_;
// Use vector for keep order
ForwardExecutorWeakPtr forward_executor_;
};
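
The other half of the fix is the suffixed cache key: TopCellInfo appends is_topest_ to the cell id (already_run_cell_id_ above), and NewGraphInner builds the same suffixed id before probing already_run_top_cell_, so a cell run at top level and the same cell inside a nested (high-order) grad no longer collide in, or go stale in, the map. In miniature (the cell id below is hypothetical):

#include <iostream>
#include <string>

// Mirrors already_run_cell_id_ = cell_id_ + std::to_string(is_topest_):
// "1" marks a top-level run, "0" a run under nested (high-order) grad.
std::string AlreadyRunCellId(const std::string &cell_id, bool is_topest) {
  return cell_id + std::to_string(is_topest);
}

int main() {
  const std::string cell_id = "NetA.construct";  // hypothetical cell id
  std::cout << AlreadyRunCellId(cell_id, true) << "\n";   // NetA.construct1
  std::cout << AlreadyRunCellId(cell_id, false) << "\n";  // NetA.construct0
  return 0;
}
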

View File

@@ -101,7 +101,6 @@ REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) {
.value("grad_for_scalar", MsCtxParam::MS_CTX_GRAD_FOR_SCALAR)
.value("save_compile_cache", MsCtxParam::MS_CTX_SAVE_COMPILE_CACHE)
.value("load_compile_cache", MsCtxParam::MS_CTX_LOAD_COMPILE_CACHE)
.value("enable_grad_cache", MsCtxParam::MS_CTX_ENABLE_GRAD_CACHE)
.value("pynative_synchronize", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE);
(void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(*m, "MSContext")
.def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.")

View File

@@ -484,7 +484,7 @@ const ActorInfo &MindRTBackend::CompileGraph(const OpRunInfo &op_run_info, const
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool enable_cache = context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAD_CACHE);
bool enable_cache = context_ptr->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
auto graph_compiler_info = ConstructGraphCompilerInfo(actor_info, tensors_mask, input_tensors, !enable_cache);
const auto actor_set = runtime::GraphScheduler::GetInstance().Transform(*graph_compiler_info);
runtime::GraphScheduler::GetInstance().Schedule(actor_set);

View File

@@ -516,7 +516,7 @@ def _check_target_specific_cfgs(device, arg_key):
enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str,
enable_sparse=bool, max_call_depth=int, env_config_path=str, graph_kernel_flags=str,
save_compile_cache=bool, load_compile_cache=bool, grad_for_scalar=bool, enable_grad_cache=bool)
save_compile_cache=bool, load_compile_cache=bool, grad_for_scalar=bool)
def set_context(**kwargs):
"""
Set context for running environment.
@@ -554,7 +554,6 @@ def set_context(**kwargs):
grad_for_scalar
save_compile_cache
load_compile_cache
enable_grad_cache
=========================== =========================== =================
Args:
@@ -666,9 +665,6 @@ def set_context(**kwargs):
you should make sure the network has not been changed since the last execution. By now, we do
not support automatically checking the changes yet. Default: False.
This is an experimental prototype that is subject to change and/or deletion.
enable_grad_cache (bool): Whether to use cache for grad, default True.
The cache will cost memory for every compiled graph.
If the input data shape is uncertain, it is advised to disable the cache to save memory.
Raises:
ValueError: If input key is not an attribute in context.
@@ -692,7 +688,6 @@ def set_context(**kwargs):
>>> context.set_context(print_file_path="print.pb")
>>> context.set_context(max_call_depth=80)
>>> context.set_context(env_config_path="./env_config.json")
>>> context.set_context(enable_grad_cache=True)
"""
ctx = _context()
# set device target first
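
With enable_grad_cache removed from set_context, cache lifetime is no longer a user-facing knob: as the pynative_execute.cc hunks above show, the runtime now disables the op-graph cache on its own when it meets a dynamic-shape operator or when top-cell counts or switches exceed MAX_TOP_CELL_COUNTS.
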

View File

@@ -89,8 +89,8 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
set_param<bool>(MS_CTX_LOAD_COMPILE_CACHE, false);
set_param<bool>(MS_CTX_ENABLE_MINDRT, false);
set_param<bool>(MS_CTX_ALREADY_SET_ENABLE_MINDRT, false);
set_param<bool>(MS_CTX_ENABLE_GRAD_CACHE, true);
set_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE, false);
set_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE, true);
backend_policy_ = policy_map_[policy];
}

View File

@@ -90,8 +90,8 @@ enum MsCtxParam : unsigned {
MS_CTX_LOAD_COMPILE_CACHE,
MS_CTX_ENABLE_MINDRT,
MS_CTX_ALREADY_SET_ENABLE_MINDRT,
MS_CTX_ENABLE_GRAD_CACHE,
MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE,
MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE,
MS_CTX_TYPE_BOOL_END,
// parameter of type int
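
One last detail on the enum change: MsCtxParam groups parameters by type, with the bool block ending at MS_CTX_TYPE_BOOL_END (and a matching begin marker outside this hunk). Assuming MsContext stores each type in an array indexed by the offset from the block's begin marker, as the paired markers suggest, the renamed flag has to stay inside the bool block, which is why it is inserted before MS_CTX_TYPE_BOOL_END. A reduced sketch under that assumption, with MiniCtx and the PARAM_* names purely illustrative:

#include <array>
#include <cassert>

// A bool block delimited by begin/end markers, as in MsCtxParam.
enum MiniCtxParam : unsigned {
  PARAM_BOOL_BEGIN,
  PARAM_ENABLE_PYNATIVE_OP_GRAPH_CACHE = PARAM_BOOL_BEGIN,
  PARAM_ENABLE_PYNATIVE_SYNCHRONIZE,
  PARAM_BOOL_END,
};

class MiniCtx {
 public:
  void set_param(MiniCtxParam p, bool v) { bool_params_[p - PARAM_BOOL_BEGIN] = v; }
  bool get_param(MiniCtxParam p) const { return bool_params_[p - PARAM_BOOL_BEGIN]; }

 private:
  // One slot per enumerator inside the bool block; a flag declared outside
  // the block would index past the end of this array.
  std::array<bool, PARAM_BOOL_END - PARAM_BOOL_BEGIN> bool_params_{};
};

int main() {
  MiniCtx ctx;
  ctx.set_param(PARAM_ENABLE_PYNATIVE_OP_GRAPH_CACHE, true);
  assert(ctx.get_param(PARAM_ENABLE_PYNATIVE_OP_GRAPH_CACHE));
  return 0;
}
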