!22000 Fix memory leak in PyNative
Merge pull request !22000 from zjun/fix_memory_leak
commit 80b6e4debc
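What the patch does, in short: every backend session's BuildOpImpl now returns the KernelGraphPtr it built (instead of void), so RunOpImpl uses the returned graph directly rather than looking it up in run_op_graphs_ a second time. Storing the graph in run_op_graphs_ is gated behind the new MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE context flag, which replaces MS_CTX_ENABLE_GRAD_CACHE, is no longer exposed through context.set_context, and is switched off via GradExecutor::EnableOpGraphCache(false) when dynamic shapes show up or too many top cells accumulate. With the flag off, a single-op graph dies with its last user instead of piling up in the cache. Below is a minimal compilable sketch of that pattern; Session, KernelGraph, and the global flag are simplified stand-ins, not the real MindSpore types.

// Minimal sketch of the new caching scheme (simplified stand-in types,
// not MindSpore's real session classes).
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

struct KernelGraph {
  std::string name;
};
using KernelGraphPtr = std::shared_ptr<KernelGraph>;
using GraphInfo = std::string;

// Stand-in for MsContext's MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE flag.
bool g_enable_op_graph_cache = true;

class Session {
 public:
  // Build (or fetch) the single-op graph and hand it back to the caller,
  // so RunOp never re-reads the cache map. Caching is now optional: with
  // the flag off, the graph is freed once the caller drops it instead of
  // being retained in run_op_graphs_ forever -- that retention was the leak.
  KernelGraphPtr BuildOp(const GraphInfo &graph_info) {
    auto it = run_op_graphs_.find(graph_info);
    if (it != run_op_graphs_.end()) {
      return it->second;
    }
    auto graph = std::make_shared<KernelGraph>();
    graph->name = graph_info;
    if (g_enable_op_graph_cache) {
      run_op_graphs_[graph_info] = graph;
    }
    return graph;
  }

  void RunOp(const GraphInfo &graph_info) {
    const auto &graph = BuildOp(graph_info);  // use the returned graph directly
    std::cout << "run " << graph->name << ", cached: " << run_op_graphs_.size() << "\n";
  }

 private:
  std::unordered_map<GraphInfo, KernelGraphPtr> run_op_graphs_;
};

int main() {
  Session session;
  session.RunOp("Add-f32");         // built and cached
  g_enable_op_graph_cache = false;  // e.g. dynamic shape seen, or too many top cells
  session.RunOp("Mul-f32");         // built, used, then released
  return 0;
}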
@@ -628,15 +628,12 @@ void AscendSession::RunOpHardwareOptimize(const std::shared_ptr<session::KernelG
   MS_LOG(INFO) << "HardwareOptimize Finish";
 }
 
-bool AscendSession::GraphCacheExist(const GraphInfo &graph_info) const {
-  return run_op_graphs_.find(graph_info) != run_op_graphs_.end();
-}
-
-void AscendSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
-                                const std::vector<tensor::TensorPtr> &input_tensors,
-                                const std::vector<int64_t> &tensors_mask) {
-  if (GraphCacheExist(graph_info)) {
-    return;
+KernelGraphPtr AscendSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
+                                          const std::vector<tensor::TensorPtr> &input_tensors,
+                                          const std::vector<int64_t> &tensors_mask) {
+  auto it = run_op_graphs_.find(graph_info);
+  if (it != run_op_graphs_.end()) {
+    return it->second;
   }
 
   const auto &graph = PreBuildOp(op_run_info, input_tensors, tensors_mask);

@@ -646,7 +643,11 @@ void AscendSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &g
   // build kernel
   RunOpAdjustKernel(graph);
   BuildKernel(graph);
-  run_op_graphs_[graph_info] = graph;
+  auto enable_op_graph_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
+  if (enable_op_graph_cache) {
+    run_op_graphs_[graph_info] = graph;
+  }
+  return graph;
 }
 
 void AscendSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,

@@ -654,7 +655,7 @@ void AscendSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_inf
                               const std::vector<int64_t> &tensors_mask) {
   MS_EXCEPTION_IF_NULL(input_tensors);
   MS_EXCEPTION_IF_NULL(op_run_info);
-  BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask);
+  const auto &graph = BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask);
   EraseValueNodeTensor(tensors_mask, input_tensors);
 
   // wait for allreduce

@@ -663,9 +664,6 @@ void AscendSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_inf
       tensor->WaitDevice();
     }
   }
-  // Run op
-  auto graph = run_op_graphs_[graph_info];
-  MS_EXCEPTION_IF_NULL(graph);
   // malloc mem
   RunOpRemoveNopNode(graph);
   RunOpMemoryAlloc(*input_tensors, graph.get());

@@ -792,7 +790,10 @@ void AscendSession::BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfN
   // Record single op graphs in run_op_graphs_ so that these graphs can be reused in BuildOpImpl
   for (const auto &graph_item : single_op_graphs) {
     RunOpMemoryClear(graph_item.first.get());
-    run_op_graphs_[graph_item.second] = graph_item.first;
+    auto enable_op_graph_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
+    if (enable_op_graph_cache) {
+      run_op_graphs_[graph_item.second] = graph_item.first;
+    }
     MS_LOG(DEBUG) << "Pre build op finished, graph info: " << graph_item.second;
   }
   built_graph_id_.insert(graph_id);

@@ -57,9 +57,9 @@ class AscendSession : public SessionBasic {
                         VectorRef *const outputs) override;
   void ExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph) override;
   void BuildGraphImpl(GraphId) override;
-  void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
-                   const std::vector<tensor::TensorPtr> &input_tensors,
-                   const std::vector<int64_t> &tensors_mask) override;
+  KernelGraphPtr BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
+                             const std::vector<tensor::TensorPtr> &input_tensors,
+                             const std::vector<int64_t> &tensors_mask) override;
   void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector<tensor::TensorPtr> *input_tensors,
                  VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override;
   void BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfNodePtr, size_t> &parameter_index,

@@ -104,8 +104,6 @@ class AscendSession : public SessionBasic {
   const std::vector<GraphId> &GetGraphOrder(GraphId final_graph_id) const;
   // get graph order type vector by graph id
   const std::vector<GraphType> &GetGraphOrderType(GraphId final_graph_id) const;
-  // check if graph cache exist
-  bool GraphCacheExist(const GraphInfo &graph_info) const;
   // sync initial tensors' data to device
   void SyncInitialTenosrToDevice();
   void SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph);

@@ -212,21 +212,27 @@ void CPUSession::ExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph)
   }
 }
 
-void CPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
-                             const std::vector<tensor::TensorPtr> &input_tensors,
-                             const std::vector<int64_t> &tensors_mask) {
-  if (run_op_graphs_.find(graph_info) != run_op_graphs_.end()) {
-    return;
+KernelGraphPtr CPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
+                                       const std::vector<tensor::TensorPtr> &input_tensors,
+                                       const std::vector<int64_t> &tensors_mask) {
+  // Check if the graph cache exists.
+  auto it = run_op_graphs_.find(graph_info);
+  if (it != run_op_graphs_.end()) {
+    return it->second;
   }
 
   // Prepare the graph
-  auto kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask);
+  const auto &kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask);
   MS_EXCEPTION_IF_NULL(kernel_graph);
   SetKernelInfo(kernel_graph.get());
   Optimize(kernel_graph);
   BuildKernel(kernel_graph.get());
   ProcessCast(kernel_graph);
-  run_op_graphs_[graph_info] = kernel_graph;
+  auto enable_op_graph_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
+  if (enable_op_graph_cache) {
+    run_op_graphs_[graph_info] = kernel_graph;
+  }
+  return kernel_graph;
 }
 
 void CPUSession::SetOutputFlags(const VectorRef &base_ref) {

@@ -260,12 +266,8 @@ void CPUSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
                            const std::vector<int64_t> &tensors_mask) {
   MS_EXCEPTION_IF_NULL(input_tensors);
   MS_EXCEPTION_IF_NULL(op_run_info);
-  BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask);
+  const auto &kernel_graph = BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask);
   EraseValueNodeTensor(tensors_mask, input_tensors);
-
-  auto kernel_graph = run_op_graphs_[graph_info];
-  MS_EXCEPTION_IF_NULL(kernel_graph);
 
   // Remove reorder after PS feature finish adapting push/pull in auto_monad.
   auto execution_order = kernel_graph->execution_order();
   Reorder(&execution_order);

@@ -43,9 +43,9 @@ class CPUSession : public SessionBasic {
   void ExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph) override;
   ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) override;
   void Optimize(const std::shared_ptr<KernelGraph> &kernel_graph);
-  void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
-                   const std::vector<tensor::TensorPtr> &input_tensors,
-                   const std::vector<int64_t> &tensors_mask) override;
+  KernelGraphPtr BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
+                             const std::vector<tensor::TensorPtr> &input_tensors,
+                             const std::vector<int64_t> &tensors_mask) override;
   void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector<tensor::TensorPtr> *input_tensors,
                  VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override;
   void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,

@@ -601,16 +601,17 @@ void GPUSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const
   }
 }
 
-void GPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
-                             const std::vector<tensor::TensorPtr> &input_tensors,
-                             const std::vector<int64_t> &tensors_mask) {
-  if (run_op_graphs_.find(graph_info) != run_op_graphs_.end() &&
-      kOpCacheBlackList.find(op_run_info.op_name) == kOpCacheBlackList.end()) {
-    return;
+KernelGraphPtr GPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
+                                       const std::vector<tensor::TensorPtr> &input_tensors,
+                                       const std::vector<int64_t> &tensors_mask) {
+  // Check if the graph cache exists.
+  auto it = run_op_graphs_.find(graph_info);
+  if (it != run_op_graphs_.end() && kOpCacheBlackList.find(op_run_info.op_name) == kOpCacheBlackList.end()) {
+    return it->second;
   }
 
   // Prepare the graph
-  auto kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask);
+  const auto &kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask);
   MS_EXCEPTION_IF_NULL(kernel_graph);
   RunOpOptimize(kernel_graph);
   SelectKernel(kernel_graph);

@@ -618,7 +619,11 @@ void GPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &grap
   StartKernelRT();
   RunOpHideNopNode(kernel_graph);
   BuildKernel(kernel_graph);
-  run_op_graphs_[graph_info] = kernel_graph;
+  auto enable_op_graph_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
+  if (enable_op_graph_cache) {
+    run_op_graphs_[graph_info] = kernel_graph;
+  }
+  return kernel_graph;
 }
 
 void GPUSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,

@@ -626,7 +631,7 @@ void GPUSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
                            const std::vector<int64_t> &tensors_mask) {
   MS_EXCEPTION_IF_NULL(input_tensors);
   MS_EXCEPTION_IF_NULL(op_run_info);
-  BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask);
+  const auto &kernel_graph = BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask);
   EraseValueNodeTensor(tensors_mask, input_tensors);
   // wait for allreduce
   for (auto &tensor : *input_tensors) {

@@ -636,7 +641,6 @@ void GPUSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
     }
   }
   // run op
-  auto kernel_graph = run_op_graphs_[graph_info];
   MS_EXCEPTION_IF_NULL(kernel_graph);
   RunOpRemoveNopNode(kernel_graph);
   RunOpAllocateMemory(*input_tensors, kernel_graph.get());

@@ -45,9 +45,9 @@ class GPUSession : public SessionBasic {
   void PostExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph, const std::vector<tensor::TensorPtr> &inputs,
                         VectorRef *const outputs) override;
   void ExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph) override;
-  void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
-                   const std::vector<tensor::TensorPtr> &input_tensors,
-                   const std::vector<int64_t> &tensors_mask) override;
+  KernelGraphPtr BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
+                             const std::vector<tensor::TensorPtr> &input_tensors,
+                             const std::vector<int64_t> &tensors_mask) override;
   void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector<tensor::TensorPtr> *input_tensors,
                  VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override;
   std::shared_ptr<device::Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) override;

@@ -217,9 +217,11 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
                                 const std::vector<tensor::TensorPtr> &inputs, VectorRef *const outputs) {}
   virtual void ExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph) {}
   void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
-  virtual void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
-                           const std::vector<tensor::TensorPtr> &input_tensors,
-                           const std::vector<int64_t> &tensors_mask) {}
+  virtual KernelGraphPtr BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
+                                     const std::vector<tensor::TensorPtr> &input_tensors,
+                                     const std::vector<int64_t> &tensors_mask) {
+    return nullptr;
+  }
   virtual void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
                          std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
                          const std::vector<int64_t> &tensors_mask) {}

@@ -90,7 +90,7 @@ FuncGraphPtr GetZerosLike(const abstract::AbstractBasePtrList &args_spec) {
   MS_EXCEPTION_IF_NULL(specialized_zeros_like_fg);
   auto opted_zeros_like_fg = ZerosLikePrimOptPass(resource);
   MS_EXCEPTION_IF_NULL(opted_zeros_like_fg);
-  auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_GRAD_CACHE);
+  auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
   if (enable_grad_cache) {
     zeros_like_funcgraph_cache[args_spec] = BasicClone(opted_zeros_like_fg);
   }

@@ -149,7 +149,7 @@ FuncGraphPtr GetOnesLike(const abstract::AbstractBasePtrList &args_spec) {
   pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
   auto specialized_ones_like_fg = pipeline::Renormalize(resource, ones_like_fg, args_spec);
   MS_EXCEPTION_IF_NULL(specialized_ones_like_fg);
-  auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_GRAD_CACHE);
+  auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
   if (enable_grad_cache) {
     ones_like_funcgraph_cache[args_spec] = BasicClone(specialized_ones_like_fg);
   }

@@ -232,7 +232,7 @@ FuncGraphPtr PrimBpropOptimizer::GetOptBpropFromCache(const FuncGraphPtr &bprop_
   auto new_abs_list = AddOutToAbsList(out, abs_list);
   level_2_graph_info = PrimBpropOptStep2(level_1_graph, new_abs_list);
   level_2_graph_info->TryFreeArgsValue(op_args, out);
-  auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_GRAD_CACHE);
+  auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
   if (enable_grad_cache) {
     level_1_graph_info->graph_level_2_cache_[abs_list] = level_2_graph_info;
     return BasicClone(level_2_graph_info->opt_func_graph());

@@ -260,7 +260,7 @@ FuncGraphPtr PrimBpropOptimizer::GenSpecOptBprop(const FuncGraphPtr &bprop_fg, c
   auto new_abs_list = AddOutToAbsList(out, abs_list);
   auto level_2_graph_info = PrimBpropOptStep2(level_1_graph_info->opt_func_graph_, new_abs_list);
   level_2_graph_info->TryFreeArgsValue(op_args, out);
-  auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_GRAD_CACHE);
+  auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
   if (!hook_flg && enable_grad_cache) {
     tuple_list_bprop_cache_[std::pair(prim, abs_list)] = BasicClone(level_2_graph_info->opt_func_graph());
   }

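These grad-cache sites keep their enable_grad_cache local name but now read the renamed MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE flag, so one context switch governs both the backend op-graph cache and these frontend funcgraph caches. Note the pattern around the cached graphs: what goes into the cache (and, in the bprop optimizer, what is returned from it) is a BasicClone, so callers can specialize the graph they receive without mutating the cached master. A rough sketch of that memoize-with-clone idiom, with invented stand-in types:

// Rough sketch of the memoize-with-clone idiom used by the funcgraph caches
// (FuncGraph, BasicClone, and the cache below are simplified stand-ins).
#include <map>
#include <memory>
#include <string>
#include <vector>

struct FuncGraph {
  std::vector<std::string> nodes;  // callers may rewrite this after the call
};
using FuncGraphPtr = std::shared_ptr<FuncGraph>;

// Deep copy, standing in for BasicClone: the cached master stays pristine.
FuncGraphPtr BasicClone(const FuncGraphPtr &g) { return std::make_shared<FuncGraph>(*g); }

bool g_enable_grad_cache = true;  // stand-in for the context flag
std::map<std::string, FuncGraphPtr> g_funcgraph_cache;

FuncGraphPtr GetOptimizedGraph(const std::string &key) {
  auto it = g_funcgraph_cache.find(key);
  if (it != g_funcgraph_cache.end()) {
    return BasicClone(it->second);  // hand out a private copy
  }
  auto fg = std::make_shared<FuncGraph>();
  fg->nodes = {"zeros_like"};  // expensive build/optimize elided
  if (g_enable_grad_cache) {
    g_funcgraph_cache[key] = BasicClone(fg);  // cache a copy, as the patch does
  }
  return fg;
}

int main() {
  auto g1 = GetOptimizedGraph("zeros_like<f32>");
  g1->nodes.push_back("specialized");  // does not touch the cached copy
  return 0;
}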
@@ -702,12 +702,6 @@ py::object GetDstType(const TypeId &type_id) {
   MS_EXCEPTION_IF_NULL(value);
   return py::cast(value);
 }
-
-void EnableGraphCache(bool flag) {
-  const auto inst = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(inst);
-  inst->set_param<bool>(MS_CTX_ENABLE_GRAD_CACHE, flag);
-}
 } // namespace
 
 py::object RealRunOp(const py::args &args) {

@@ -992,7 +986,7 @@ void ForwardExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
     if (shape->IsDynamic()) {
       op_exec_info->is_dynamic_shape = true;
       // Dynamic shape operator in the current top cell, disable backend cache
-      EnableGraphCache(false);
+      grad()->EnableOpGraphCache(false);
     }
   }
 

@@ -1016,7 +1010,7 @@ void ForwardExecutor::GetOpOutput(const OpExecInfoPtr &op_exec_info,
   }
 
   // Add output abstract info into cache, the const value needs to infer evert step
-  if (!prim_cache_hit && !op_exec_info->is_dynamic_shape) {
+  if (grad()->enable_op_cache() && !prim_cache_hit && !op_exec_info->is_dynamic_shape) {
     AbsCacheKey key{prim->name(), prim->Hash(), prim->attrs()};
     auto &out = prim_abs_list_[key];
     out[args_spec_list].abs = op_exec_info->abstract;

@@ -1338,6 +1332,13 @@ TopCellInfoPtr GradExecutor::GetTopCell(const std::string &cell_id) const {
   return nullptr;
 }
 
+void GradExecutor::EnableOpGraphCache(bool is_enable) {
+  enable_op_cache_ = is_enable;
+  const auto inst = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(inst);
+  inst->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE, is_enable);
+}
+
 void GradExecutor::RecordGradOpInfo(const OpExecInfoPtr &op_exec_info, const py::object &ret) {
   if (!grad_flag_) {
     MS_LOG(DEBUG) << "Grad flag is set to false, no need to record op info";

@@ -1515,7 +1516,7 @@ void GradExecutor::UpdateForwardTensorInfoInBpropGraph(const OpExecInfoPtr &op_e
   }
 
   // First run top cell
-  if (already_run_top_cell_.find(top_cell_->cell_id()) == already_run_top_cell_.end()) {
+  if (already_run_top_cell_.find(top_cell_->already_run_cell_id()) == already_run_top_cell_.end()) {
     MS_LOG(DEBUG) << "Top cell " << top_cell_->cell_id() << " run firstly";
     if (!need_construct_graph()) {
       MS_LOG(EXCEPTION) << "The cell stack is empty when running a new top cell " << top_cell_->cell_id();

@@ -1523,7 +1524,7 @@ void GradExecutor::UpdateForwardTensorInfoInBpropGraph(const OpExecInfoPtr &op_e
     return;
   }
   // Non-first run
-  const auto &pre_top_cell = already_run_top_cell_.at(top_cell_->cell_id());
+  const auto &pre_top_cell = already_run_top_cell_.at(top_cell_->already_run_cell_id());
   MS_EXCEPTION_IF_NULL(pre_top_cell);
   if (pre_top_cell->op_info_with_tensor_id().find(op_info) == pre_top_cell->op_info_with_tensor_id().end()) {
     MS_LOG(DEBUG) << "Can not find op info " << op_info << " in op info with tensor id map. Top cell "

@@ -1895,13 +1896,12 @@ void GradExecutor::ClearCellRes(const std::string &cell_id) {
   }
   // clear when cell destruction
   for (auto it = top_cell_list_.begin(); it != top_cell_list_.end();) {
-    auto top_cell_id = (*it)->cell_id();
+    const auto &top_cell_id = (*it)->cell_id();
+    const auto &alreay_top_cell_id = (*it)->already_run_cell_id();
     if (IsCellObjIdEq(cell_id, top_cell_id)) {
       (*it)->Clear();
       it = top_cell_list_.erase(it);
-      if (already_run_top_cell_.find(top_cell_id) != already_run_top_cell_.end()) {
-        (void)already_run_top_cell_.erase(top_cell_id);
-      }
+      (void)already_run_top_cell_.erase(alreay_top_cell_id);
       MS_LOG(DEBUG) << "Clear top cell resource. Top cell id " << top_cell_id;
       continue;
     }

@@ -1952,7 +1952,7 @@ void GradExecutor::HandleInputArgsForTopCell(const py::args &args, bool is_bprop
   }
   // Convert input args to parameters for top cell graph in construct.
   std::vector<ValuePtr> input_param_values;
-  py::list only_tensors = FilterTensorArgs(args);
+  const auto &only_tensors = FilterTensorArgs(args);
   auto df_builder = GetDfbuilder(top_cell_->cell_id());
   MS_EXCEPTION_IF_NULL(df_builder);
   for (size_t i = 0; i < only_tensors.size(); ++i) {

@@ -2017,11 +2017,18 @@ void GradExecutor::InitResourceAndDfBuilder(const std::string &cell_id, const py
 
 void GradExecutor::NewGraphInner(py::object *ret, const py::object &cell, const py::args &args) {
   MS_EXCEPTION_IF_NULL(ret);
-  auto cell_id = GetCellId(cell, args);
+  const auto &cell_id = GetCellId(cell, args);
   MS_LOG(DEBUG) << "NewGraphInner start " << args.size() << " " << cell_id;
   if (top_cell_ != nullptr && cell_stack_.empty()) {
+    // Already run top cell need distinguish high order; high order add "0" otherwise "1"
+    std::string already_run_cell_id;
+    if (IsNestedGrad()) {
+      already_run_cell_id = cell_id + "0";
+    } else {
+      already_run_cell_id = cell_id + "1";
+    }
     // Whether it is top and has been run
-    auto top_it = already_run_top_cell_.find(cell_id);
+    auto top_it = already_run_top_cell_.find(already_run_cell_id);
     if (top_it != already_run_top_cell_.end()) {
       // Top cell forward run.
       const auto &pre_top_cell = top_it->second;

@@ -2032,8 +2039,8 @@ void GradExecutor::NewGraphInner(py::object *ret, const py::object &cell, const
       set_top_cell(pre_top_cell);
       return;
     }
-  } else if ((top_cell()->IsSubCell(cell_id) && !IsCellObjIdEq(cell_id, check_graph_cell_id_)) ||
-             GetHighOrderStackSize() >= 1) {
+  } else if ((top_cell()->IsSubCell(cell_id) || GetHighOrderStackSize() >= 1) &&
+             !IsCellObjIdEq(cell_id, check_graph_cell_id_)) {
    // Sub cell ( or may be a temporary cell, but must be non top) forward run in cache process.
    MS_LOG(DEBUG) << "Sub cell no need to run NewGraphInner again";
    return;

@@ -2069,13 +2076,11 @@ void GradExecutor::MakeNewTopGraph(const string &cell_id, const py::args &args,
   // The number of top cell exceeds MAX_TOP_CELL_COUNTS, delete the last one to keep the maximum length of the list,
   // disable backend cache
   if (top_cell_list_.size() >= MAX_TOP_CELL_COUNTS) {
-    EnableGraphCache(false);
+    EnableOpGraphCache(false);
     const auto last_top_cell = top_cell_list_.back();
     top_cell_list_.pop_back();
     last_top_cell->Clear();
-    if (already_run_top_cell_.find(last_top_cell->cell_id()) != already_run_top_cell_.end()) {
-      (void)already_run_top_cell_.erase(last_top_cell->cell_id());
-    }
+    (void)already_run_top_cell_.erase(last_top_cell->already_run_cell_id());
   }
   // Create top cell
   curr_g_ = std::make_shared<FuncGraph>();

@@ -2535,16 +2540,16 @@ py::object PynativeExecutor::CheckAlreadyRun(const py::object &cell, const py::a
 
 void GradExecutor::CheckNeedCompileGraph() {
   auto new_top_cell = top_cell();
-  std::string top_cell_id = new_top_cell->cell_id();
-  // update top cell by current cell op info
-  if (already_run_top_cell_.find(top_cell_id) == already_run_top_cell_.end()) {
-    MS_LOG(DEBUG) << "Top cell " << top_cell_id << " has never been ran, need compile graph";
-    already_run_top_cell_[top_cell_id] = new_top_cell;
+  const auto &already_top_cell_id = new_top_cell->already_run_cell_id();
+  // Update top cell by current cell op info
+  if (already_run_top_cell_.find(already_top_cell_id) == already_run_top_cell_.end()) {
+    MS_LOG(DEBUG) << "Top cell " << new_top_cell->cell_id() << " has never been ran, need compile graph";
+    already_run_top_cell_[already_top_cell_id] = new_top_cell;
     return;
   }
 
-  MS_LOG(DEBUG) << "Top cell " << top_cell_id << " has been ran";
-  auto pre_top_cell = already_run_top_cell_.at(top_cell_id);
+  MS_LOG(DEBUG) << "Top cell " << new_top_cell->cell_id() << " has been ran";
+  auto pre_top_cell = already_run_top_cell_.at(already_top_cell_id);
   auto pre_all_op_info = pre_top_cell->all_op_info();
   auto new_all_op_info = new_top_cell->all_op_info();
   MS_LOG(DEBUG) << "Pre all op info : " << pre_all_op_info;

@@ -2553,14 +2558,14 @@ void GradExecutor::CheckNeedCompileGraph() {
     MS_LOG(DEBUG) << "The op info has been changed, need to compile graph again";
     // The top cell switches exceeds MAX_TOP_CELL_COUNTS under the control flow, disable backend cache
     if (top_cell_switch_counts_ >= MAX_TOP_CELL_COUNTS) {
-      EnableGraphCache(false);
+      EnableOpGraphCache(false);
     } else {
       // Increase top cell switches counts
       ++top_cell_switch_counts_;
     }
     EraseTopCellFromTopCellList(pre_top_cell);
     pre_top_cell->Clear();
-    already_run_top_cell_[top_cell_id] = new_top_cell;
+    already_run_top_cell_[already_top_cell_id] = new_top_cell;
   } else {
     MS_LOG(DEBUG) << "The op info has not been changed, no need to compile graph again";
     pre_top_cell->set_input_args_id(new_top_cell->input_args_id());

@@ -2813,6 +2818,7 @@ void GradExecutor::ClearRes() {
   grad_flag_ = false;
   need_renormalize_ = false;
   grad_is_running_ = false;
+  enable_op_cache_ = true;
   top_cell_ = nullptr;
   curr_g_ = nullptr;
   bprop_cell_list_.clear();

@@ -42,7 +42,6 @@
 
 namespace mindspore::pynative {
 namespace py = pybind11;
-using CellId = std::string;
 using MsFunctionGradCache = std::unordered_map<std::string, std::pair<FuncGraphPtr, FuncGraphPtr>>;
 using OpInfoWithTensorId = std::unordered_map<std::string, std::vector<std::string>>;
 using TensorIdWithTensorObject = std::unordered_map<std::string, std::vector<tensor::TensorPtr>>;

@@ -68,7 +67,8 @@ class TopCellInfo {
         grad_order_(grad_order),
         resource_(std::move(r)),
         df_builder_(std::move(df)),
-        cell_id_(std::move(cellid)) {}
+        cell_id_(std::move(cellid)),
+        alread_run_cell_id_(cell_id_ + std::to_string(is_topest_)) {}
 
   bool is_init_kpynative() const { return is_init_kpynative_; }
   void set_init_kpynative(bool init) { is_init_kpynative_ = init; }

@@ -90,9 +90,10 @@ class TopCellInfo {
   size_t op_num() const { return op_num_; }
   void set_op_num(size_t op_num) { op_num_ = op_num; }
   std::string &cell_id() { return cell_id_; }
+  std::string &already_run_cell_id() { return alread_run_cell_id_; }
   std::string &input_args_id() { return input_args_id_; }
   std::string &all_op_info() { return all_op_info_; }
-  void set_input_args_id(const std::string &input_args_id) { input_args_id_ = std::move(input_args_id); }
+  void set_input_args_id(const std::string &input_args_id) { input_args_id_ = input_args_id; }
   std::unordered_set<std::string> &sub_cell_list() { return sub_cell_list_; }
   bool IsSubCell(const std::string &cell_id) const;
   OrderedMap<FuncGraphPtr, GraphInfoPtr> &graph_info_map() { return graph_info_map_; }

@@ -124,6 +125,7 @@ class TopCellInfo {
   FuncGraphPtr df_builder_{nullptr};
   ad::KPynativeCellPtr k_pynative_cell_ptr_{nullptr};
   std::string cell_id_;
+  std::string alread_run_cell_id_;
   std::string input_args_id_;
   std::string all_op_info_;
   OrderedMap<FuncGraphPtr, GraphInfoPtr> graph_info_map_;

@@ -173,7 +175,9 @@ class GradExecutor {
   TopCellInfoPtr top_cell() const;
   void CheckNeedCompileGraph();
   TopCellInfoPtr GetTopCell(const string &cell_id) const;
+  void EnableOpGraphCache(bool is_enable);
   bool need_renormalize() const { return need_renormalize_; }
+  bool enable_op_cache() const { return enable_op_cache_; }
   void set_top_cell(TopCellInfoPtr top_cell) { top_cell_ = std::move(top_cell); }
   bool grad_flag() const { return grad_flag_; }
   void set_grad_flag(bool flag) { grad_flag_ = flag; }

@@ -242,15 +246,15 @@ class GradExecutor {
                                   const std::vector<int64_t> &index_sequence, bool is_param = false);
   void SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node,
                                   bool is_param = false);
-  void SetParamNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const ParameterPtr &param) {
+  void SetParamNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const ParameterPtr &param) const {
     top_cell()->graph_info_map()[g]->params[id] = param;
   }
   void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node,
-                                int64_t index = -1) {
+                                int64_t index = -1) const {
     top_cell()->graph_info_map()[g]->node_map[id] = std::make_pair(node, std::vector<int64_t>{index});
   }
   void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node,
-                                const std::vector<int64_t> &index) {
+                                const std::vector<int64_t> &index) const {
     top_cell()->graph_info_map()[g]->node_map[id] = std::make_pair(node, index);
   }
   void CreateMakeTupleNodeForMultiOut(const FuncGraphPtr &curr_g, const py::object &out, const std::string &out_id);

@@ -259,6 +263,7 @@ class GradExecutor {
   bool grad_flag_{false};
   bool need_renormalize_{false};
   bool grad_is_running_{false};
+  bool enable_op_cache_{true};
   int custom_bprop_cell_count_{0};
   size_t grad_order_{0};
   size_t top_cell_switch_counts_{0};

@@ -281,7 +286,7 @@ class GradExecutor {
   // Use vector for keep order
   std::vector<TopCellInfoPtr> top_cell_list_;
   // Record all top cell which has been ran
-  std::map<CellId, TopCellInfoPtr> already_run_top_cell_;
+  std::unordered_map<std::string, TopCellInfoPtr> already_run_top_cell_;
   // Use vector for keep order
   ForwardExecutorWeakPtr forward_executor_;
 };

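The other half of the change reworks how top cells are keyed. already_run_top_cell_ is now an unordered_map<std::string, TopCellInfoPtr> keyed by already_run_cell_id(): the plain cell id plus a one-character suffix ("1" for a topmost cell, "0" for one reached through a nested/high-order grad, per the constructor's std::to_string(is_topest_) and the branch in NewGraphInner). That keeps a cell that runs both as a plain top cell and under a nested grad from colliding on one cache slot. A compressed sketch of the keying, with a hypothetical simplified TopCell in place of the real TopCellInfo:

// Compressed sketch of the already_run_cell_id keying (hypothetical
// simplified TopCell, not the real TopCellInfo).
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

struct TopCell {
  std::string cell_id;
  bool is_topest;
  // Mirrors alread_run_cell_id_(cell_id_ + std::to_string(is_topest_)).
  std::string already_run_cell_id() const { return cell_id + (is_topest ? "1" : "0"); }
};
using TopCellPtr = std::shared_ptr<TopCell>;

int main() {
  std::unordered_map<std::string, TopCellPtr> already_run_top_cell;
  auto plain = std::make_shared<TopCell>(TopCell{"NetA", true});
  auto nested = std::make_shared<TopCell>(TopCell{"NetA", false});  // same cell under nested grad
  already_run_top_cell[plain->already_run_cell_id()] = plain;       // key "NetA1"
  already_run_top_cell[nested->already_run_cell_id()] = nested;     // key "NetA0", no collision
  std::cout << already_run_top_cell.size() << "\n";                 // prints 2
  return 0;
}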
@@ -101,7 +101,6 @@ REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) {
                          .value("grad_for_scalar", MsCtxParam::MS_CTX_GRAD_FOR_SCALAR)
                          .value("save_compile_cache", MsCtxParam::MS_CTX_SAVE_COMPILE_CACHE)
                          .value("load_compile_cache", MsCtxParam::MS_CTX_LOAD_COMPILE_CACHE)
-                         .value("enable_grad_cache", MsCtxParam::MS_CTX_ENABLE_GRAD_CACHE)
                          .value("pynative_synchronize", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE);
                        (void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(*m, "MSContext")
                          .def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.")

@@ -484,7 +484,7 @@ const ActorInfo &MindRTBackend::CompileGraph(const OpRunInfo &op_run_info, const
 
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
-  bool enable_cache = context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAD_CACHE);
+  bool enable_cache = context_ptr->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
   auto graph_compiler_info = ConstructGraphCompilerInfo(actor_info, tensors_mask, input_tensors, !enable_cache);
   const auto actor_set = runtime::GraphScheduler::GetInstance().Transform(*graph_compiler_info);
   runtime::GraphScheduler::GetInstance().Schedule(actor_set);

@@ -516,7 +516,7 @@ def _check_target_specific_cfgs(device, arg_key):
                  enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
                  enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str,
                  enable_sparse=bool, max_call_depth=int, env_config_path=str, graph_kernel_flags=str,
-                 save_compile_cache=bool, load_compile_cache=bool, grad_for_scalar=bool, enable_grad_cache=bool)
+                 save_compile_cache=bool, load_compile_cache=bool, grad_for_scalar=bool)
 def set_context(**kwargs):
     """
     Set context for running environment.

@@ -554,7 +554,6 @@ def set_context(**kwargs):
     grad_for_scalar
     save_compile_cache
     load_compile_cache
-    enable_grad_cache
     =========================== =========================== =================
 
     Args:

@@ -666,9 +665,6 @@ def set_context(**kwargs):
             you should make sure the network has not been changed since the last execution. By now, we have
             not support automatically checking the changes yet. Default: False.
             This is an experimental prototype that is subject to change and/or deletion.
-        enable_grad_cache (bool): Whether to use cache for grad, default True.
-            The cache will cost memory for every compiled graph.
-            If the input data shape is uncertian, advised to disable the cache for save memory.
 
     Raises:
         ValueError: If input key is not an attribute in context.

@@ -692,7 +688,6 @@ def set_context(**kwargs):
         >>> context.set_context(print_file_path="print.pb")
        >>> context.set_context(max_call_depth=80)
        >>> context.set_context(env_config_path="./env_config.json")
-        >>> context.set_context(enable_grad_cache=True)
     """
     ctx = _context()
     # set device target first

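A usage note on the Python side: enable_grad_cache disappears from the public context.set_context API entirely (decorator signature, docstring table, parameter description, and example are all dropped), so the op-graph cache is now controlled internally by the PyNative executor rather than by user configuration; per the docstring above, an unknown key passed to set_context raises ValueError.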
@@ -89,8 +89,8 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
   set_param<bool>(MS_CTX_LOAD_COMPILE_CACHE, false);
   set_param<bool>(MS_CTX_ENABLE_MINDRT, false);
   set_param<bool>(MS_CTX_ALREADY_SET_ENABLE_MINDRT, false);
-  set_param<bool>(MS_CTX_ENABLE_GRAD_CACHE, true);
   set_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE, false);
+  set_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE, true);
 
   backend_policy_ = policy_map_[policy];
 }

@@ -90,8 +90,8 @@ enum MsCtxParam : unsigned {
   MS_CTX_LOAD_COMPILE_CACHE,
   MS_CTX_ENABLE_MINDRT,
   MS_CTX_ALREADY_SET_ENABLE_MINDRT,
-  MS_CTX_ENABLE_GRAD_CACHE,
   MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE,
+  MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE,
   MS_CTX_TYPE_BOOL_END,
 
   // parameter of type int