diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 5728b7109ee..0bd6299fdb2 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -63,7 +63,6 @@ #include "debug/anf_ir_dump.h" using mindspore::tensor::TensorPy; - const size_t PTR_LEN = 15; // primitive unable to infer value for constant input in PyNative mode static const std::set vm_operators = {"make_ref", "HookBackward", "InsertGradientOf", "stop_gradient", @@ -637,16 +636,11 @@ OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) { auto op_name = py::cast(args[PY_NAME]); op_exec_info->op_name = op_name; if (grad_flag()) { - int64_t graph_id = graph_id_; - auto resource = GetResource(top_cell_id_); - if (resource != nullptr) { - MS_LOG(DEBUG) << "Get resource ptr " << resource.get(); - auto it = resource->results().find(pipeline::kPynativeGraphId); - if (it != resource->results().end()) { - graph_id = it->second.cast(); - } + op_exec_info->op_index = op_name + std::to_string(op_index_map_[op_name]); + if (!cell_op_info_stack_.empty()) { + std::string &cell_op_info = cell_op_info_stack_.top(); + cell_op_info += op_exec_info->op_index; } - op_exec_info->op_index = std::to_string(graph_id) + op_name + std::to_string(op_index_map_[op_name]); op_index_map_[op_name]++; } auto prim = py::cast(args[PY_PRIM]); @@ -968,9 +962,11 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) { MS_LOG(EXCEPTION) << "Parameter object should have name attribute"; } auto param_name = py::cast(name_attr); - auto df_builder = GetDfbuilder(); + auto df_builder = GetDfbuilder(top_cell_id_); MS_EXCEPTION_IF_NULL(df_builder); - if (graph_info_map_.at(df_builder).params.find(obj_id) == graph_info_map_.at(df_builder).params.end()) { + auto graph_info = graph_info_map_.at(df_builder); + MS_EXCEPTION_IF_NULL(graph_info); + if (graph_info->params.find(obj_id) == graph_info->params.end()) { auto free_param = df_builder->add_parameter(); free_param->set_name(param_name); free_param->debug_info()->set_name(param_name); @@ -983,12 +979,14 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) { SetNodeMapInGraphInfoMap(curr_g_, obj_id, free_param); return free_param; } - node = graph_info_map_.at(df_builder).node_map[obj_id].first; + node = graph_info->node_map.at(obj_id).first; MS_LOG(DEBUG) << "Get input param node " << node->ToString() << " " << obj_id; return node; } - if (graph_info_map_.at(curr_g_).node_map.find(obj_id) != graph_info_map_.at(curr_g_).node_map.end()) { + auto graph_info = graph_info_map_.at(curr_g_); + MS_EXCEPTION_IF_NULL(graph_info); + if (graph_info->node_map.find(obj_id) != graph_info->node_map.end()) { // op(x, y) // out = op(op1(x, y)) // out = op(cell1(x, y)) @@ -1099,14 +1097,13 @@ void PynativeExecutor::SaveTensorsInValueNode(const ResourcePtr &resource) { } } -void PynativeExecutor::CleanPreMemoryInValueNode(const std::string &cell_id) { +void PynativeExecutor::CleanPreMemoryInValueNode() { auto ms_context = MsContext::GetInstance(); std::string device_target = ms_context->get_param(MS_CTX_DEVICE_TARGET); if (device_target == "CPU") { - top_cell_id_ = cell_id; return; } - if (dynamic_cell_) { + if (has_dynamic_cell_) { std::set forward_op_tensor_id; for (const auto &elem : cell_op_index_with_tensor_id_[top_cell_id_]) { const auto &tensor_id_list = elem.second; @@ -1131,18 +1128,19 @@ void 
PynativeExecutor::CleanPreMemoryInValueNode(const std::string &cell_id) { tensor->set_device_address(nullptr); } } - top_cell_id_ = cell_id; } AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj, const std::string &obj_id) { - auto &out = graph_info_map_.at(curr_g_).node_map[obj_id]; + auto graph_info = graph_info_map_.at(curr_g_); + MS_EXCEPTION_IF_NULL(graph_info); + auto &out = graph_info->node_map.at(obj_id); if (out.second.size() == 1 && out.second[0] == -1) { return out.first; } MS_LOG(DEBUG) << "Output size " << out.second.size(); // Params node - if (graph_info_map_.at(curr_g_).params.find(obj_id) != graph_info_map_.at(curr_g_).params.end()) { + if (graph_info->params.find(obj_id) != graph_info->params.end()) { auto para_node = out.first; for (auto &idx : out.second) { std::vector tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), para_node, @@ -1463,6 +1461,11 @@ py::object PynativeExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynati void PynativeExecutor::PushCurrentGraphToStack() { graph_stack_.push(curr_g_); } +void PynativeExecutor::PushCurrentCellOpInfoToStack() { + std::string cell_op_info = "Cell ops: "; + cell_op_info_stack_.push(cell_op_info); +} + void PynativeExecutor::PopGraphStack() { if (graph_stack_.empty()) { MS_LOG(EXCEPTION) << "Stack graph_stack_ is empty"; @@ -1473,6 +1476,13 @@ void PynativeExecutor::PopGraphStack() { } } +void PynativeExecutor::PopCurrentCellOpInfoFromStack() { + if (cell_op_info_stack_.empty()) { + MS_LOG(EXCEPTION) << "The cell op info stack is empty"; + } + cell_op_info_stack_.pop(); +} + std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args &args) { auto cell_id = GetId(cell); for (size_t i = 0; i < args.size(); i++) { @@ -1480,14 +1490,53 @@ std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args & auto it = node_abs_map_.find(arg_id); if (it != node_abs_map_.end()) { cell_id += "_" + it->second->BuildShape()->ToString(); - cell_id += "_" + it->second->BuildType()->ToString(); + cell_id += it->second->BuildType()->ToString(); } else { auto abs = PyAttrValue(args[i])->ToAbstract(); auto config = abstract::AbstractBase::kBroadenTensorOnly; abs = abs->Broaden(config); - cell_id += "_" + abs->BuildShape()->ToString(); - cell_id += "_" + abs->BuildType()->ToString(); node_abs_map_[arg_id] = abs; + cell_id += "_" + abs->BuildShape()->ToString(); + cell_id += abs->BuildType()->ToString(); + } + } + return GetTensorCellId(cell_id); +} + +std::string PynativeExecutor::GetTensorCellId(const std::string &cell_id) { + if (cell_id.find("NoShape") == std::string::npos) { + return cell_id; + } + std::string key = cell_id.substr(0, PTR_LEN); + auto fn = [](const std::string &str, std::vector &value) { + size_t pos = 0; + size_t pre_pos = 0; + while ((pos = str.find_first_of('_', pre_pos)) != std::string::npos) { + value.emplace_back(str.substr(pre_pos, pos - pre_pos + 1)); + pre_pos = pos + 1; + } + value.emplace_back(str.substr(pre_pos)); + }; + auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), [&key](const CellInfoPtr &value) { + return value->cell_id.find(key) != std::string::npos && value->cell_id.find("Tensor") != std::string::npos; + }); + if (it != cell_graph_list_.end()) { + std::vector pre_cell_id; + std::vector cur_cell_id; + fn((*it)->cell_id, pre_cell_id); + fn(cell_id, cur_cell_id); + auto pre_tensor_size = pre_cell_id.size(); + if (pre_tensor_size == cur_cell_id.size()) { + size_t same_tensor_count = 0; + for (size_t i = 0; i < 
pre_tensor_size; ++i) { + if (cur_cell_id[i].find("NoShape") != std::string::npos || cur_cell_id[i] == pre_cell_id[i]) { + ++same_tensor_count; + } + } + if (same_tensor_count == pre_tensor_size) { + MS_LOG(DEBUG) << "Changed cell id from " << cell_id << " to " << (*it)->cell_id; + return (*it)->cell_id; + } } } return cell_id; @@ -1499,22 +1548,37 @@ void PynativeExecutor::DumpGraphIR(const std::string &filename, const FuncGraphP } } -bool PynativeExecutor::IsNotNestedGrad() const { - MS_LOG(DEBUG) << "Grad nested count is " << grad_order_; - return grad_order_ <= 1; +bool PynativeExecutor::IsNestedGrad() const { + MS_LOG(DEBUG) << "Grad nested order is " << grad_order_; + return grad_order_ > 1; } bool PynativeExecutor::IsTopGraph(const std::string &cell_id) { return std::any_of(top_cell_list_.begin(), top_cell_list_.end(), - [&cell_id](const TopCellInfo &value) { return value.cell_id == cell_id; }); + [&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id; }); +} + +bool PynativeExecutor::IsTopestGraph(const std::string &cell_id) { + return std::any_of(top_cell_list_.begin(), top_cell_list_.end(), + [&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id && value->is_topest; }); +} + +void PynativeExecutor::UpdateTopCellCompileInfo(const std::string &cell_id, bool vm_compiled) { + auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(), + [&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id; }); + if (it != top_cell_list_.end()) { + (*it)->do_vm_compiled = vm_compiled; + } } bool PynativeExecutor::IsBpropGraph(const std::string &cell_id) { - return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id](const CellInfo &value) { - return !value.bprop_cell_id.empty() && cell_id.find(value.bprop_cell_id) != std::string::npos; + return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id](const CellInfoPtr &value) { + return !value->bprop_cell_id.empty() && cell_id.find(value->bprop_cell_id) != std::string::npos; }); } +bool PynativeExecutor::IsFirstGradStep(const std::string &cell_id) { return !CheckCellGraph(cell_id, true); } + void PynativeExecutor::SubNestedGradOrder() { if (grad_order_ > 0) { --grad_order_; @@ -1522,27 +1586,31 @@ void PynativeExecutor::SubNestedGradOrder() { } bool PynativeExecutor::CheckCellGraph(const std::string &cell_id, bool is_grad) { - auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(), [&cell_id](const TopCellInfo &value) { - return value.cell_id == cell_id && value.is_dynamic_cell; + return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id, is_grad](const CellInfoPtr &value) { + return value->cell_id == cell_id && (!is_grad || value->is_grad); }); - if (it != top_cell_list_.end()) { - return false; - } - return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id, is_grad](const CellInfo &value) { - return value.cell_id == cell_id && (!is_grad || value.is_grad); +} + +bool PynativeExecutor::CheckDynamicCell(const std::string &cell_id) { + return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), + [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id && value->is_dynamic; }); +} + +bool PynativeExecutor::CheckRealDynamicCell(const std::string &cell_id) { + return std::any_of(cell_graph_list_.begin(), cell_graph_list_.end(), [&cell_id](const CellInfoPtr &value) { + return value->cell_id == cell_id && value->is_real_dynamic; }); } void 
PynativeExecutor::ClearResidualRes(const std::string &cell_id) { + // Abnormal case if (top_cell_list_.empty() && !graph_stack_.empty()) { graph_id_ = 0; graph_info_map_.clear(); - cell_sw_map_.clear(); cell_graph_list_.clear(); std::stack().swap(graph_stack_); } - if (dynamic_cell_) { - VectorClear>(&top_cell_list_, cell_id); + if (CheckRealDynamicCell(cell_id)) { if (IsTopGraph(cell_id) && graph_stack_.empty() && !IsBpropGraph(cell_id)) { // Clear previous step resource auto resource = GetResource(cell_id); @@ -1556,62 +1624,28 @@ void PynativeExecutor::ClearResidualRes(const std::string &cell_id) { } FuncGraphPtr PynativeExecutor::GetDfbuilder(const std::string &cell_id) { - // Cell is empty, get nearest dfbuilder - if (cell_id.empty() && !top_cell_list_.empty()) { - if (top_cell_list_.size() == 1) { - return top_cell_list_.begin()->df_builder; - } - if (grad_order_ == 0 || grad_order_ == 1) { - return top_cell_list_.back().df_builder; - } - if (top_cell_list_.size() < 2) { - MS_LOG(EXCEPTION) << "Top cell list size must greater than 2"; - } - MS_LOG(DEBUG) << "Get grad order " << grad_order_ << " top cell list size " << top_cell_list_.size(); - // Grad order greater than 2 - auto it = top_cell_list_.end(); - std::advance(it, -2); - return it->df_builder; - } // If top graph hold - for (const auto &it : top_cell_list_) { - if (cell_id.find(it.cell_id) != std::string::npos) { - return it.df_builder; + for (auto it = top_cell_list_.rbegin(); it != top_cell_list_.rend(); ++it) { + if (cell_id.find((*it)->cell_id) != std::string::npos) { + return (*it)->df_builder; } } // Current cell is not top graph, get first top cell if (!top_cell_list_.empty()) { - return top_cell_list_.front().df_builder; + return top_cell_list_.front()->df_builder; } return nullptr; } ResourcePtr PynativeExecutor::GetResource(const std::string &cell_id) { - // Cell is empty, get nearest resource - if (cell_id.empty() && !top_cell_list_.empty()) { - if (top_cell_list_.size() == 1) { - return top_cell_list_.begin()->resource; - } - if (grad_order_ == 0 || grad_order_ == 1) { - return top_cell_list_.back().resource; - } - if (top_cell_list_.size() < 2) { - MS_LOG(EXCEPTION) << "Top cell list size must greater than 2"; - } - MS_LOG(DEBUG) << "Get grad order " << grad_order_ << " top cell list size " << top_cell_list_.size(); - // Grad order greater than 2 - auto it = top_cell_list_.end(); - std::advance(it, -2); - return it->resource; - } - for (const auto &it : top_cell_list_) { - if (cell_id.find(it.cell_id) != std::string::npos) { - return it.resource; + for (auto it = top_cell_list_.rbegin(); it != top_cell_list_.rend(); ++it) { + if (cell_id.find((*it)->cell_id) != std::string::npos) { + return (*it)->resource; } } // Current cell is not top graph, get first top cell if (!top_cell_list_.empty()) { - return top_cell_list_.front().resource; + return top_cell_list_.front()->resource; } return nullptr; } @@ -1830,28 +1864,31 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg auto cell_id = GetCellId(cell, args); MS_LOG(DEBUG) << "NewGraphInner start " << args.size() << " " << cell_id; // check whether cell needed to construct grad graph - if (!dynamic_cell_ && graph_stack_.empty() && CheckCellGraph(cell_id)) { + if (graph_stack_.empty() && CheckCellGraph(cell_id) && !CheckDynamicCell(cell_id)) { if (top_cell_list_.empty()) { MS_LOG(EXCEPTION) << "Top cell list is empty"; } - if (IsTopGraph(cell_id)) { + if (IsTopestGraph(cell_id)) { // Clear previous step resource + 
CleanPreMemoryInValueNode(); op_index_map_.clear(); - CleanPreMemoryInValueNode(cell_id); + top_cell_id_ = cell_id; } + PushCurrentCellOpInfoToStack(); MS_LOG(INFO) << "NewGraph already compiled"; return; } - // init resource for constructing forward graph and grad graph + // Init resource for constructing forward graph and grad graph curr_g_ = std::make_shared(); ClearResidualRes(cell_id); if (graph_stack_.empty() && !IsBpropGraph(cell_id)) { - MakeNewTopGraph(cell_id, args, curr_g_); + MakeNewTopGraph(cell_id, args); } PushCurrentGraphToStack(); + PushCurrentCellOpInfoToStack(); if (graph_info_map_.find(curr_g_) == graph_info_map_.end()) { - GraphInfo graph_info = GraphInfo(cell_id); - graph_info_map_.emplace(curr_g_, graph_info); + auto graph_info = std::make_shared(cell_id); + graph_info_map_[curr_g_] = graph_info; } for (size_t i = 0; i < args.size(); ++i) { auto param = args[i]; @@ -1861,21 +1898,14 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg SetNodeMapInGraphInfoMap(curr_g_, param_id, new_param); SetParamNodeMapInGraphInfoMap(curr_g_, param_id, new_param); } - // check whether the construct of cell will be changed - if (!dynamic_cell_) { - dynamic_cell_ = IsDynamicCell(cell); - if (dynamic_cell_) { - auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(), - [&](const TopCellInfo &value) { return value.cell_id == top_cell_id_; }); - if (it != top_cell_list_.end()) { - it->is_dynamic_cell = dynamic_cell_; - } - } - MS_LOG(DEBUG) << "cell id: " << cell_id << ", is dynamic cell: " << dynamic_cell_; + // Check whether the construct of cell will be changed + if (!has_dynamic_cell_) { + has_dynamic_cell_ = IsDynamicCell(cell); + MS_LOG(DEBUG) << "cell id: " << cell_id << ", is dynamic cell: " << has_dynamic_cell_; } } -void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &args, const FuncGraphPtr &g) { +void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &args) { for (const auto &arg : args) { if (py::isinstance(arg)) { auto tensor = arg.cast(); @@ -1885,26 +1915,45 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar } } // Clear resource in old top cell - auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(), - [&cell_id](const TopCellInfo &value) { return value.cell_id == cell_id; }); - if (it != top_cell_list_.end()) { - top_cell_list_.erase(it); + if (CheckRealDynamicCell(cell_id)) { + VectorClear>(&top_cell_list_, cell_id); } - op_index_map_.clear(); - CleanPreMemoryInValueNode(cell_id); + CleanPreMemoryInValueNode(); // Init resource for new top cell - dynamic_cell_ = false; + if (!CheckCellGraph(cell_id)) { + has_dynamic_cell_ = false; + } + op_index_map_.clear(); + top_cell_id_ = cell_id; auto df_builder = std::make_shared(); - GraphInfo graph_info = GraphInfo(cell_id); - graph_info_map_.emplace(df_builder, graph_info); + auto graph_info = std::make_shared(cell_id); + graph_info_map_[df_builder] = graph_info; auto resource = std::make_shared(); resource->results()[pipeline::kPynativeGraphId] = graph_id_++; - auto top_cell_info = TopCellInfo(resource, df_builder, nullptr, cell_id); + auto top_cell_info = std::make_shared(true, resource, df_builder, cell_id); top_cell_list_.emplace_back(top_cell_info); MS_LOG(DEBUG) << "New top graph, df_builder ptr " << df_builder.get() << " resource ptr " << resource.get(); } +std::string PynativeExecutor::GetCellOpInfo() { + if (cell_op_info_stack_.empty()) { + MS_LOG(EXCEPTION) << "The cell op info 
stack is empty"; + } + return cell_op_info_stack_.top(); +} + +void PynativeExecutor::ReplaceCellOpInfoByCellId(const std::string &cell_id) { + if (cell_id.empty()) { + MS_LOG(EXCEPTION) << "The cell id is empty"; + } + if (cell_op_info_stack_.empty()) { + MS_LOG(DEBUG) << "The cell op info stack is empty, No need replace"; + return; + } + cell_op_info_stack_.top() = cell_op_info_stack_.top() + cell_id; +} + void PynativeExecutor::SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node, bool is_param) { if (!py::isinstance(args) && !py::isinstance(args)) { @@ -1949,13 +1998,16 @@ void PynativeExecutor::SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, con void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &out, const py::args &args) { const auto &cell_id = GetCellId(cell, args); MS_LOG(DEBUG) << "EndGraphInner start " << args.size() << " " << cell_id; - if (!dynamic_cell_ && graph_stack_.empty() && CheckCellGraph(cell_id)) { + if (graph_stack_.empty() && CheckCellGraph(cell_id) && !CheckDynamicCell(cell_id)) { + PopCurrentCellOpInfoFromStack(); MS_LOG(INFO) << "Endgraph already compiled"; return; } auto out_id = GetId(out); // x =op1, y =op2, return (x, y) - if (graph_info_map_.at(curr_g_).node_map.find(out_id) == graph_info_map_.at(curr_g_).node_map.end()) { + auto graph_info = graph_info_map_.at(curr_g_); + MS_EXCEPTION_IF_NULL(graph_info); + if (graph_info->node_map.find(out_id) == graph_info->node_map.end()) { if (py::isinstance(out) || py::isinstance(out)) { auto tuple = out.cast(); auto tuple_size = static_cast(tuple.size()); @@ -1985,17 +2037,26 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string MS_LOG(DEBUG) << "Get bprop function cell"; return; } - auto resource = GetResource(cell_id); + auto resource = GetResource(top_cell_id_); MS_EXCEPTION_IF_NULL(resource); resource->manager()->AddFuncGraph(curr_g_); UpdateCellGraph(cell, curr_g_, cell_id, true, false); - auto newfg = MakeGradGraph(cell, curr_g_, resource, cell_id, args); + FuncGraphPtr newfg = nullptr; + // Cell no Change + if (CheckDynamicCell(cell_id) && !CheckCellChanged(cell_id)) { + MS_LOG(DEBUG) << "Cell is not dynamic, No need make ad grad"; + } else { + MS_LOG(DEBUG) << "Need make ad grad"; + newfg = MakeGradGraph(cell, curr_g_, resource, cell_id, args); + } if (graph_stack_.size() > 1) { std::vector inputs; inputs.emplace_back(NewValueNode(curr_g_)); PopGraphStack(); + PopCurrentCellOpInfoFromStack(); + ReplaceCellOpInfoByCellId(cell_id); // connect the previous graph to the inside graph auto graph_prev = graph_stack_.top(); for (size_t i = 0; i < args.size(); i++) { @@ -2007,70 +2068,164 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string SetTupleArgsToGraphInfoMap(graph_prev, out, out_cnode); SetNodeMapInGraphInfoMap(graph_prev, GetId(out), out_cnode); } else { - DumpGraphIR("before_resolve.ir", newfg); - parse::ResolveFuncGraph(newfg, resource); - DumpGraphIR("after_resolve.ir", newfg); - resource->set_func_graph(newfg); + if (newfg != nullptr) { + DumpGraphIR("before_resolve.ir", newfg); + parse::ResolveFuncGraph(newfg, resource); + DumpGraphIR("after_resolve.ir", newfg); + resource->set_func_graph(newfg); + } PopGraphStack(); + PopCurrentCellOpInfoFromStack(); } } bool PynativeExecutor::EndBpropGraph(const string &cell_id) { auto is_bprop_graph = IsBpropGraph(cell_id); if (is_bprop_graph) { - if (IsNotNestedGrad()) { + if (!IsNestedGrad()) { PopGraphStack(); + 
 
 bool PynativeExecutor::EndBpropGraph(const string &cell_id) {
   auto is_bprop_graph = IsBpropGraph(cell_id);
   if (is_bprop_graph) {
-    if (IsNotNestedGrad()) {
+    if (!IsNestedGrad()) {
       PopGraphStack();
+      PopCurrentCellOpInfoFromStack();
+      ReplaceCellOpInfoByCellId(cell_id);
     }
     return true;
   }
   return false;
 }
 
+bool PynativeExecutor::CheckCellChanged(const std::string &cell_id) {
+  bool res = false;
+  if (CheckRealDynamicCell(cell_id)) {
+    MS_LOG(DEBUG) << "Cur cell " << cell_id << " is dynamic, no need to check";
+    return true;
+  }
+  if (GetCellOpInfo().empty()) {
+    MS_LOG(DEBUG) << "Cell op info is empty";
+    return true;
+  }
+  auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(),
+                         [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; });
+  if (it == cell_graph_list_.end() || IsFirstGradStep(top_cell_id_)) {
+    return true;
+  }
+  MS_LOG(DEBUG) << "Cell op info " << GetCellOpInfo() << ", old " << (*it)->cell_ops_info.at((*it)->call_times);
+  if ((*it)->cell_ops_info.at((*it)->call_times) != GetCellOpInfo()) {
+    res = true;
+    UpdateCellDynamic(cell_id);
+    MS_LOG(DEBUG) << "Cell itself changed";
+  }
+  (*it)->call_times = (*it)->call_times < (*it)->cell_ops_info.size() - 1 ? (*it)->call_times + 1 : 0;
+  return res;
+}
+
+void PynativeExecutor::UpdateCellDynamic(const std::string &cell_id) {
+  for (auto &it : cell_graph_list_) {
+    it->is_real_dynamic = true;
+    if (it->cell_id == cell_id) {
+      break;
+    }
+  }
+}
+
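UpdateCellGraph below re-keys graph_info_map_ with graph_info_map_.update(...) whenever a forward graph is cloned or swapped, relying on the OrderedMap::update added in ordered_map.h further down in this patch. A minimal sketch of the simple re-keying case on a plain std::map follows; Update is a hypothetical helper, and the patch's branch for an already-existing new key is subtler and omitted here:

#include <cassert>
#include <map>
#include <string>

// Move the value stored under old_key so it is reachable under new_key,
// using the same C++17 node-handle extract/insert as ordered_map.h.
template <typename K, typename V>
void Update(std::map<K, V> *m, const K &old_key, const K &new_key) {
  auto old_it = m->find(old_key);
  if (old_it == m->end()) {
    return;
  }
  if (m->find(new_key) == m->end()) {
    auto nh = m->extract(old_key);
    nh.key() = new_key;
    m->insert(std::move(nh));
    return;
  }
  // New key already present: keep its entry, drop the stale one.
  m->erase(old_key);
}

int main() {
  std::map<std::string, int> graph_info{{"old_graph", 1}};
  Update(&graph_info, std::string("old_graph"), std::string("cloned_graph"));
  assert(graph_info.count("cloned_graph") == 1 && graph_info.count("old_graph") == 0);
  return 0;
}

This keeps the graph-info record alive across BasicClone without copying it, which the removed code did by assigning graph_info_map_[cloned] = graph_info_map_.at(g).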
 void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
                                        bool need_cloned, bool is_grad) {
+  auto update_in_endgraph = need_cloned && !is_grad;
   if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
     // Bprop just save backward graph
     auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(),
-                           [&cell_id](const CellInfo &value) { return value.cell_id == cell_id; });
+                           [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; });
     if (it != cell_graph_list_.end()) {
-      it->is_grad = is_grad;
-      it->fg = g;
+      (*it)->is_grad = is_grad;
+      if (g != (*it)->fg) {
+        graph_info_map_.update((*it)->fg, g);
+        (*it)->fg = g;
+      }
+      if (update_in_endgraph && IsFirstGradStep(top_cell_id_)) {
+        (*it)->cell_ops_info.emplace_back(GetCellOpInfo());
+      }
       MS_LOG(DEBUG) << "Update bprop bg cell id " << cell_id;
     } else {
       py::function bprop_func = py::getattr(cell, parse::CUSTOM_BPROP_NAME);
       auto bprop_func_cell_id = GetId(bprop_func);
-      MS_LOG(DEBUG) << "Add new bprop cell_id " << cell_id << " bprop func cell id " << bprop_func_cell_id;
-      auto cell_info = CellInfo(false, true, g, cell_id, bprop_func_cell_id);
+      MS_LOG(DEBUG) << "Add new bprop cell_id " << cell_id << " bprop func cell id " << bprop_func_cell_id
+                    << " cell ops info " << GetCellOpInfo();
+      auto cell_info = std::make_shared<CellInfo>(true, has_dynamic_cell_, g, cell_id, bprop_func_cell_id);
+      cell_info->cell_ops_info.emplace_back(GetCellOpInfo());
       cell_graph_list_.insert(cell_graph_list_.begin(), cell_info);
     }
     return;
   }
+
   FuncGraphPtr tmp = g;
-  if (need_cloned && !IsNotNestedGrad()) {
-    auto cloned_curr_g = BasicClone(g);
-    graph_info_map_[cloned_curr_g] = graph_info_map_.at(g);
-    tmp = cloned_curr_g;
-    MS_LOG(DEBUG) << "Replace cur graph " << g.get() << " with cloned new " << cloned_curr_g.get();
+  if (!IsFirstGradStep(top_cell_id_) && CheckDynamicCell(cell_id) && !CheckRealDynamicCell(cell_id)) {
+    MS_LOG(DEBUG) << "No need to clone";
+    need_cloned = false;
   }
-  for (auto &it : cell_graph_list_) {
-    if (it.cell_id != cell_id) {
-      continue;
+  auto clone_fn = [&g, &tmp, need_cloned, this]() {
+    if (!need_cloned) {
+      return;
     }
-    it.is_grad = is_grad;
-    if (need_cloned) {
-      it.fg = tmp;
-    }
-    if (!need_cloned && !is_grad) {
-      graph_info_map_[g] = graph_info_map_.at(it.fg);
-      graph_info_map_.erase(it.fg);
-      it.fg = g;
-      MS_LOG(DEBUG) << "Replace cur graph " << it.fg.get() << " with new " << g.get();
+    tmp = BasicClone(g);
+    graph_info_map_.update(g, tmp);
+    ClearCnodeRes(tmp->output());
+  };
+  // First call, or the cell id does not exist yet
+  if (update_in_endgraph && (IsFirstGradStep(top_cell_id_) || !CheckCellGraph(cell_id))) {
+    if (!CheckCellGraph(cell_id)) {
+      clone_fn();
+      MS_LOG(DEBUG) << "Add new cell with cloned graph " << cell_id << " cell ops info " << GetCellOpInfo();
+      auto cell_info = std::make_shared<CellInfo>(true, has_dynamic_cell_, tmp, cell_id, "");
+      cell_info->cell_ops_info.emplace_back(GetCellOpInfo());
+      cell_graph_list_.insert(cell_graph_list_.begin(), cell_info);
+    } else {
+      auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(),
+                             [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; });
+      if (it != cell_graph_list_.end()) {
+        (*it)->cell_ops_info.emplace_back(GetCellOpInfo());
+      }
+      MS_LOG(DEBUG) << "Add op info for another call of the same cell";
+    }
     return;
   }
-  MS_LOG(DEBUG) << "Add new cell graph " << cell_id;
-  auto cell_info = CellInfo(false, true, tmp, cell_id, "");
-  cell_graph_list_.insert(cell_graph_list_.begin(), cell_info);
+
+  for (auto &it : cell_graph_list_) {
+    if (it->cell_id != cell_id) {
+      continue;
+    }
+    if (IsFirstGradStep(cell_id)) {
+      // No need to compute grad yet
+      it->is_grad = is_grad;
+    }
+    if (need_cloned) {
+      clone_fn();
+      if (it->fg != nullptr) {
+        graph_info_map_.erase(it->fg);
+      }
+      MS_LOG(DEBUG) << "Update cur graph " << it->fg.get() << " with cloned new " << tmp.get();
+      it->fg = tmp;
+    }
+    if (!need_cloned && !is_grad) {
+      graph_info_map_.erase(it->fg);
+      MS_LOG(DEBUG) << "Update cur graph " << it->fg.get() << " with new " << tmp.get();
+      it->fg = tmp;
+    }
+    break;
+  }
+}
+
+void PynativeExecutor::ClearCnodeRes(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  if (!node->isa<CNode>()) {
+    return;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  cnode->clear_inputs_value();
+  cnode->set_forward(nullptr, "");
+  for (size_t i = 0; i < cnode->size(); ++i) {
+    ClearCnodeRes(cnode->input(i));
+  }
 }
 
 FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const FuncGraphPtr &g, const ResourcePtr &r,
@@ -2092,7 +2247,7 @@ FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const FuncG
   DumpGraphIR("fg.ir", g);
   auto is_top = IsTopGraph(cell_id);
   MS_LOG(DEBUG) << "Grad top cell " << is_top;
-  set_need_replace_forward(IsNotNestedGrad());
+  set_need_replace_forward(!IsNestedGrad());
 
   // Obtain grad graph
   auto newfg = ad::Grad(g, r, is_top);
@@ -2171,9 +2326,11 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje
   py::object sens = py::none();
   py::object forward_args = args;
   const auto &cell_id = GetGradCellId(grad->sens_param(), cell, args, &forward_args, &sens);
-  MS_LOG(DEBUG) << "GradNet start " << args.size() << " " << cell_id;
-  const auto &sw_changed = CheckCellChanged(cell_id, weights, sens);
-  if (!dynamic_cell_ && !sw_changed.second && CheckCellGraph(cell_id, true)) {
+  MS_LOG(DEBUG) << "GradNet start " << size << " " << cell_id;
+  const auto &params_changed = CheckGradParamsChanged(cell_id, weights, sens);
+  if (!params_changed && !IsFirstGradStep(cell_id) && !CheckRealDynamicCell(cell_id)) {
+    UpdateTopCellCompileInfo(cell_id, false);
+    ClearDynamicTopRes(cell_id);
     MS_LOG(INFO) << "Gradgraph already compiled";
     return;
   }
@@ -2206,20 +2363,54 @@ void PynativeExecutor::GradNetInner(const 
GradOperationPtr &grad, const py::obje resource->results()[pipeline::kBackend] = compile::CreateBackend(); MS_LOG(INFO) << "Start opt"; - if (dynamic_cell_) { + if (has_dynamic_cell_) { SaveAllValueNodeTensors(resource->func_graph()); } PynativeOptimizeAction(resource); + DumpGraphIR("after_opt.ir", resource->func_graph()); SaveTensorsInValueNode(resource); TaskEmitAction(resource); ExecuteAction(resource); + ClearUselessRes(df_builder, cell, cell_id); UpdateCellGraph(cell, curr_g_, cell_id, false, true); - UpdateGraphInfoMap(cell_id); + UpdateTopCellCompileInfo(cell_id, true); resource->Clean(); } -std::pair PynativeExecutor::CheckCellChanged(const std::string &cell_id, const py::object &weights, - const py::object &sens) { +void PynativeExecutor::ClearDynamicTopRes(const std::string &cell_id) { + if (IsTopestGraph(cell_id)) { + op_index_map_.clear(); + } + // Delete unused top cell resource + if (!CheckDynamicCell(cell_id)) { + return; + } + int same_top_cell_count = 0; + for (auto it = top_cell_list_.begin(); it != top_cell_list_.end();) { + if ((*it)->cell_id == cell_id) { + ++same_top_cell_count; + if (same_top_cell_count > 1) { + graph_info_map_.erase((*it)->df_builder); + it = top_cell_list_.erase(it); + --same_top_cell_count; + } else { + ++it; + } + } else { + ++it; + } + } +} + +bool PynativeExecutor::CheckGradParamsChanged(const std::string &cell_id, const py::object &weights, + const py::object &sens) { + bool res = false; + auto it = std::find_if(top_cell_list_.begin(), top_cell_list_.end(), + [&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id; }); + if (it == top_cell_list_.end()) { + return res; + } + auto fn = [](const py::object &arg) { std::string arg_id; if (py::isinstance(arg)) { @@ -2240,31 +2431,26 @@ std::pair PynativeExecutor::CheckCellChanged(const std::string &cell sens_id = fn(sens); } + if (!(*it)->sens_id.empty() && (*it)->sens_id != sens_id) { + (*it)->sens_id = sens_id; + } std::string weights_id = fn(weights); - std::pair sens_weights_changed(false, false); - // Check whether sens or weights changed - auto it = cell_sw_map_.find(cell_id); - if (it != cell_sw_map_.end() && it->second.first != sens_id) { - MS_LOG(DEBUG) << "Sens_id, cur is " << it->second.first << " new is " << sens_id; - sens_weights_changed.first = true; + if (!(*it)->weights_id.empty() && (*it)->weights_id != weights_id) { + (*it)->weights_id = weights_id; + res = true; } - if (it != cell_sw_map_.end() && it->second.second != weights_id) { - MS_LOG(DEBUG) << "Weights_id, cur is " << it->second.first << " new is " << weights_id; - sens_weights_changed.second = true; - } - cell_sw_map_[cell_id] = std::make_pair(sens_id, weights_id); - return sens_weights_changed; + return res; } void PynativeExecutor::SetNestedTopGraph(const py::object &cell, const py::args &args, const std::string &cell_id) { if (IsTopGraph(cell_id)) { - return; + VectorClear>(&top_cell_list_, cell_id); } ResourcePtr resource = nullptr; auto ia = std::find_if(top_cell_list_.begin(), top_cell_list_.end(), - [&cell_id](const TopCellInfo &value) { return value.cell_id == cell_id; }); + [&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id; }); if (ia != top_cell_list_.end()) { - resource = GetResource(ia->cell_id); + resource = GetResource((*ia)->cell_id); MS_EXCEPTION_IF_NULL(resource); MS_LOG(DEBUG) << "Find old resource " << resource.get(); } @@ -2275,15 +2461,15 @@ void PynativeExecutor::SetNestedTopGraph(const py::object &cell, const py::args } MS_EXCEPTION_IF_NULL(resource); 
FuncGraphPtr df_builder = std::make_shared(); - GraphInfo graph_info = GraphInfo(cell_id); - graph_info_map_.emplace(df_builder, graph_info); - auto top_cell_info = TopCellInfo(resource, df_builder, nullptr, cell_id); + auto graph_info = std::make_shared(cell_id); + graph_info_map_[df_builder] = graph_info; + auto top_cell_info = std::make_shared(false, resource, df_builder, cell_id); top_cell_list_.emplace_back(top_cell_info); FuncGraphPtr forward_graph = nullptr; auto ib = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), - [&cell_id](const CellInfo &value) { return value.cell_id == cell_id; }); + [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; }); if (ib != cell_graph_list_.end()) { - forward_graph = ib->fg; + forward_graph = (*ib)->fg; } MS_EXCEPTION_IF_NULL(forward_graph); if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { @@ -2295,32 +2481,59 @@ void PynativeExecutor::SetNestedTopGraph(const py::object &cell, const py::args } // Copy weights parameters + ReplaceGraphParams(df_builder, forward_graph, cell_id); resource->manager()->AddFuncGraph(forward_graph); - auto manager = Manage({forward_graph}, false); - for (const auto &it : graph_info_map_.at(forward_graph).params) { - if (!it.second->has_default()) { - continue; - } - auto new_param = df_builder->add_parameter(); - new_param->set_abstract(it.second->abstract()); - new_param->set_name(it.second->name()); - new_param->set_default_param(it.second->default_param()); - ScopePtr scope = (it.second->scope() != kDefaultScope) ? it.second->scope() : kDefaultScope; - new_param->set_scope(scope); - manager->Replace(it.second, new_param); - replace_weights_map_[forward_graph].emplace_back(std::make_pair(it.second, new_param)); - MS_LOG(DEBUG) << "Old param ptr " << it.second.get() << " name " << it.second->name(); - - graph_info_map_.at(df_builder).params[it.first] = new_param; - SetParamNodeMapInGraphInfoMap(df_builder, it.first, new_param); - SetNodeMapInGraphInfoMap(df_builder, it.first, new_param); - } DumpGraphIR("nested_fg.ir", forward_graph); set_need_replace_forward(false); auto newfg = MakeGradGraph(cell, forward_graph, resource, cell_id, args); resource->set_func_graph(newfg); } +void PynativeExecutor::ReplaceGraphParams(const FuncGraphPtr &df_builder, const FuncGraphPtr &forward_graph, + const std::string &cell_id) { + std::vector graph_before{}; + bool index_find = false; + for (const auto &it : cell_graph_list_) { + if (IsBpropGraph(it->cell_id) || it->fg == nullptr) { + continue; + } + if (index_find) { + graph_before.emplace_back(it->fg); + continue; + } + if (it->cell_id == cell_id) { + index_find = true; + graph_before.emplace_back(it->fg); + } + } + + auto manager = Manage({forward_graph}, false); + for (const auto &f : graph_before) { + auto graph_info = graph_info_map_.at(f); + MS_EXCEPTION_IF_NULL(graph_info); + for (const auto &it : graph_info->params) { + if (!it.second->has_default()) { + continue; + } + auto new_param = df_builder->add_parameter(); + new_param->set_abstract(it.second->abstract()); + new_param->set_name(it.second->name()); + new_param->set_default_param(it.second->default_param()); + ScopePtr scope = (it.second->scope() != kDefaultScope) ? 
it.second->scope() : kDefaultScope;
+      new_param->set_scope(scope);
+      manager->Replace(it.second, new_param);
+      replace_weights_map_[forward_graph].emplace_back(std::make_pair(it.second, new_param));
+      MS_LOG(DEBUG) << "Param name " << new_param->name() << " ptr " << new_param.get();
+
+      auto graph_info_of_df_builder = graph_info_map_.at(df_builder);
+      MS_EXCEPTION_IF_NULL(graph_info_of_df_builder);
+      graph_info_of_df_builder->params[it.first] = new_param;
+      SetParamNodeMapInGraphInfoMap(df_builder, it.first, new_param);
+      SetNodeMapInGraphInfoMap(df_builder, it.first, new_param);
+    }
+  }
+}
+
 void PynativeExecutor::SetGradGraphParams(const FuncGraphPtr &df_builder, const ResourcePtr &resource, size_t size) {
   std::vector<AnfNodePtr> new_params;
   for (size_t i = 0; i < size; i++) {
@@ -2347,9 +2560,11 @@ std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weigh
     auto param = tuple[it];
     auto param_id = GetId(param);
     AnfNodePtr para_node = nullptr;
-    if (graph_info_map_.at(df_builder).params.find(param_id) != graph_info_map_.at(df_builder).params.end() &&
-        graph_info_map_.at(df_builder).node_map.find(param_id) != graph_info_map_.at(df_builder).node_map.end()) {
-      para_node = graph_info_map_.at(df_builder).node_map[param_id].first;
+    auto graph_info = graph_info_map_.at(df_builder);
+    MS_EXCEPTION_IF_NULL(graph_info);
+    if (graph_info->params.find(param_id) != graph_info->params.end() &&
+        graph_info->node_map.find(param_id) != graph_info->node_map.end()) {
+      para_node = graph_info->node_map[param_id].first;
     } else {
       auto name_attr = parse::python_adapter::GetPyObjAttr(param, "name");
       if (py::isinstance<py::none>(name_attr)) {
@@ -2407,9 +2622,9 @@ void PynativeExecutor::GradGraph(const FuncGraphPtr &g, const GradOperationPtr &
                                  const std::vector<AnfNodePtr> &weights, size_t arg_size, const std::string &cell_id) {
   FuncGraphPtr top_g = nullptr;
   auto it = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(),
-                         [&cell_id](const CellInfo &value) { return value.cell_id == cell_id; });
+                         [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; });
   if (it != cell_graph_list_.end()) {
-    top_g = it->fg;
+    top_g = (*it)->fg;
   }
   MS_EXCEPTION_IF_NULL(top_g);
   auto nparam = top_g->parameters().size();
@@ -2438,22 +2653,35 @@ void PynativeExecutor::GradGraph(const FuncGraphPtr &g, const GradOperationPtr &
   resource->manager()->AddFuncGraph(df_builder);
 }
 
-void PynativeExecutor::UpdateGraphInfoMap(const std::string &cell_id) {
+void PynativeExecutor::ClearUselessRes(const FuncGraphPtr &df_builder, const py::object &cell,
+                                       const std::string &cell_id) {
+  graph_info_map_.erase(df_builder);
+  bool has_custom_bprop = py::hasattr(cell, parse::CUSTOM_BPROP_NAME);
+  bool is_dynamic_top_first_grad = CheckDynamicCell(cell_id) && IsFirstGradStep(cell_id);
+  bool is_topmost = IsTopestGraph(cell_id) && top_cell_list_.front()->cell_id == cell_id;
+  if (has_custom_bprop || is_dynamic_top_first_grad || !is_topmost) {
+    return;
+  }
+
+  MS_LOG(DEBUG) << "Update topmost cell graph list and graph info map";
+  // Clear graph_info_map_
   std::vector<std::string> l{};
   bool index_find = false;
-  for (const auto &it : cell_graph_list_) {
+  for (auto &it : cell_graph_list_) {
     if (index_find) {
-      l.emplace_back(it.cell_id);
+      it->fg = nullptr;
+      l.emplace_back(it->cell_id);
      continue;
    }
-    if (it.cell_id == cell_id) {
+    if (it->cell_id == cell_id) {
      index_find = true;
+      it->fg = nullptr;
+      l.emplace_back(it->cell_id);
-      l.emplace_back(it.cell_id);
    }
  }
  for (const auto &it : l) {
    for (auto ic = graph_info_map_.begin(); ic != graph_info_map_.end();) {
-      if (ic->second.cell_id.find(it) != std::string::npos) {
+      if (ic->second->cell_id.find(it) != std::string::npos) {
         ic = graph_info_map_.erase(ic);
       } else {
         ++ic;
       }
@@ -2470,14 +2698,15 @@ py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args &
     return BaseRefToPyData(ret);
   }
   const auto &cell_id = GetCellId(cell, args);
-  string key = cell_id.substr(0, std::min(PTR_LEN, cell_id.size()));
+  std::string key = cell_id.substr(0, std::min(PTR_LEN, cell_id.size()));
   MS_LOG(DEBUG) << "Key is " << key;
   for (auto it = cell_graph_list_.begin(); it != cell_graph_list_.end(); ++it) {
-    MS_LOG(DEBUG) << "Cur cell id " << it->cell_id;
-    if (key != it->cell_id.substr(0, std::min(PTR_LEN, it->cell_id.size()))) {
+    MS_LOG(DEBUG) << "Cur cell id " << (*it)->cell_id;
+    if (key != (*it)->cell_id.substr(0, std::min(PTR_LEN, (*it)->cell_id.size()))) {
       continue;
     }
     MS_LOG(DEBUG) << "Delete cellid from cell graph list";
+    graph_info_map_.erase((*it)->fg);
     cell_graph_list_.erase(it);
     ret = true;
     break;
@@ -2485,12 +2714,19 @@ py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args &
   return BaseRefToPyData(ret);
 }
 
+py::object PynativeExecutor::CheckAlreadyRun(const py::object &cell, const py::args &args) {
+  const auto &cell_id = GetCellId(cell, args);
+  bool already_run = CheckCellGraph(cell_id);
+  MS_LOG(DEBUG) << "Graph has already run " << already_run << " cell id " << cell_id;
+  return BaseRefToPyData(already_run);
+}
+
 py::object PynativeExecutor::Run(const py::object &cell, const py::tuple &args, const py::object &phase) {
   auto cell_id = GetCellId(cell, args);
   MS_LOG(DEBUG) << "Run start cell id " << cell_id;
   bool has_sens = false;
   for (const auto &it : top_cell_list_) {
-    if (cell_id.find(it.cell_id) != std::string::npos && cell_id != it.cell_id) {
+    if (cell_id.find(it->cell_id) != std::string::npos && cell_id != it->cell_id) {
       has_sens = true;
       break;
     }
@@ -2521,10 +2757,15 @@ py::object PynativeExecutor::Run(const py::object &cell, const py::tuple &args,
   set_grad_runing(false);
   MS_LOG(DEBUG) << "Eval run end " << value.ToString();
   auto out = BaseRefToPyData(value);
-  if (MakeBpropNestedCnode(cell, out, cell_id)) {
-    return out;
+  auto do_vm_compiled =
+    std::any_of(top_cell_list_.begin(), top_cell_list_.end(),
+                [&cell_id](const TopCellInfoPtr &value) { return value->cell_id == cell_id && value->do_vm_compiled; });
+  if (do_vm_compiled) {
+    if (MakeBpropNestedCnode(cell, out, cell_id)) {
+      return out;
+    }
+    MakeNestedCnode(cell_id, args, resource, out, has_sens);
   }
-  MakeNestedCnode(cell_id, args, resource, out, has_sens);
   return out;
 }
 
@@ -2537,7 +2778,9 @@ bool PynativeExecutor::MakeBpropNestedCnode(const py::object &cell, const py::ob
   std::vector<AnfNodePtr> inputs;
   inputs.emplace_back(NewValueNode(curr_g_));
   PopGraphStack();
-  for (const auto &ig : graph_info_map_.at(curr_g_).params) {
+  auto graph_info = graph_info_map_.at(curr_g_);
+  MS_EXCEPTION_IF_NULL(graph_info);
+  for (const auto &ig : graph_info->params) {
     if (!ig.second->has_default()) {
       inputs.emplace_back(ig.second);
     }
@@ -2569,8 +2812,8 @@ void PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::arg
   for (size_t i = 0; i < inputs_size; ++i) {
     inputs.emplace_back(GetInput(args[i], false));
   }
-  if (newfg->parameters().size() > inputs_size) {
-    SetNestedWeightsParam(newfg, cell_id, &inputs);
+  if (newfg->parameters().size() > args.size()) {
+    RecoverGraphParams(newfg, cell_id, &inputs);
   }
   auto out_id = GetId(out);
   auto cnode = graph_prev->NewCNode(inputs);
@@ -2579,15 +2822,16 @@ void
PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::arg MS_LOG(DEBUG) << "Nested make cnode is " << cnode->DebugString(4); } -void PynativeExecutor::SetNestedWeightsParam(const FuncGraphPtr &newfg, const std::string &cell_id, - std::vector *inputs) { +void PynativeExecutor::RecoverGraphParams(const FuncGraphPtr &newfg, const std::string &cell_id, + std::vector *inputs) { FuncGraphPtr forward_graph = nullptr; auto ic = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(), - [&cell_id](const CellInfo &value) { return value.cell_id == cell_id; }); + [&cell_id](const CellInfoPtr &value) { return value->cell_id == cell_id; }); if (ic != cell_graph_list_.end()) { - forward_graph = ic->fg; + forward_graph = (*ic)->fg; } MS_EXCEPTION_IF_NULL(forward_graph); + auto param_list = replace_weights_map_.at(forward_graph); auto params = newfg->parameters(); auto manage = Manage({newfg}, false); for (const auto &it : params) { @@ -2595,15 +2839,12 @@ void PynativeExecutor::SetNestedWeightsParam(const FuncGraphPtr &newfg, const st if (!param->has_default()) { continue; } - auto ir = replace_weights_map_.find(forward_graph); - if (ir == replace_weights_map_.end()) { - MS_LOG(EXCEPTION) << "Not find forward_graph in repalce weigths map"; - } - for (const auto &ip : ir->second) { - MS_LOG(DEBUG) << "Get param name " << param->name() << " cache name " << ip.second->name(); - if (ip.second->name() == param->name()) { - manage->Replace(param, ip.first); - inputs->emplace_back(ip.first); + for (auto p = param_list.begin(); p != param_list.end();) { + MS_LOG(DEBUG) << "Param name " << param->name() << " ptr " << param.get(); + if (p->second->name() == param->name()) { + manage->Replace(param, p->first); + inputs->emplace_back(p->first); + param_list.erase(p); break; } } @@ -2619,7 +2860,7 @@ void PynativeExecutor::Clear(const std::string &cell_id) { MS_LOG(DEBUG) << "Clear cell res, cell id " << cell_id; for (auto it = graph_info_map_.begin(); it != graph_info_map_.end();) { - if (it->second.cell_id.find(cell_id) != std::string::npos) { + if (it->second->cell_id.find(cell_id) != std::string::npos) { it = graph_info_map_.erase(it); } else { ++it; @@ -2631,10 +2872,8 @@ void PynativeExecutor::Clear(const std::string &cell_id) { ms_context->set_param(MS_CTX_ENABLE_PYNATIVE_INFER, false); } ConfigManager::GetInstance().ResetIterNum(); - MapClear>(&cell_dynamic_map_, cell_id); - MapClear>>(&cell_sw_map_, cell_id); - VectorClear>(&cell_graph_list_, cell_id); - VectorClear>(&top_cell_list_, cell_id); + VectorClear>(&cell_graph_list_, cell_id); + VectorClear>(&top_cell_list_, cell_id); node_abs_map_.clear(); } @@ -2653,20 +2892,18 @@ void PynativeExecutor::ClearRes() { graph_id_ = 0; grad_order_ = 0; grad_flag_ = false; - dynamic_cell_ = false; + has_dynamic_cell_ = false; grad_is_running_ = false; need_replace_forward_ = true; curr_g_ = nullptr; graph_info_map_.clear(); - cell_sw_map_.clear(); replace_weights_map_.clear(); cell_graph_list_.clear(); top_cell_list_.clear(); op_index_map_.clear(); cell_op_index_with_tensor_id_.clear(); cell_tensor_id_with_tensor_.clear(); - cell_dynamic_map_.clear(); prim_abs_list_.clear(); std::stack().swap(graph_stack_); } @@ -2720,6 +2957,7 @@ REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) { .def("new_graph", &PynativeExecutor::NewGraph, "pynative new a graph.") .def("end_graph", &PynativeExecutor::EndGraph, "pynative end a graph.") .def("check_graph", &PynativeExecutor::CheckGraph, "pynative check a grad graph.") + .def("check_run", 
&PynativeExecutor::CheckAlreadyRun, "pynative check graph run before.") .def("grad_net", &PynativeExecutor::GradNet, "pynative grad graph.") .def("clear", &PynativeExecutor::Clear, "pynative clear status.") .def("sync", &PynativeExecutor::Sync, "pynative sync stream.") diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h index eb330fabd2f..4bbb0575947 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h @@ -69,32 +69,47 @@ struct GraphInfo { explicit GraphInfo(std::string id) : cell_id(std::move((id))) {} }; -struct CellInfo { - bool is_grad{false}; // Derivative is calculated - bool is_custom_bprop{false}; // Custom bprop - FuncGraphPtr fg; // Forward graph - std::string cell_id; - std::string bprop_cell_id; +class CellInfo { + public: CellInfo() = default; - CellInfo(bool isgrad, bool custom_bprop, FuncGraphPtr foward_graph, std::string cellid, std::string bprop_id) - : is_grad(isgrad), - is_custom_bprop(custom_bprop), + CellInfo(bool custom_bprop, bool has_dynamic, FuncGraphPtr foward_graph, std::string cellid, std::string bprop_id) + : is_custom_bprop(custom_bprop), + is_dynamic(has_dynamic), fg(std::move(foward_graph)), cell_id(std::move(cellid)), bprop_cell_id(std::move(bprop_id)) {} + + bool is_grad{false}; // Derivative is calculated + bool is_custom_bprop{false}; // Custom bprop + bool is_dynamic{false}; // Set by has_dynamic_cell + bool is_real_dynamic{false}; // Set by ops order + size_t call_times{0}; + FuncGraphPtr fg{nullptr}; // Forward graph + std::string cell_id; + std::string bprop_cell_id; + std::vector cell_ops_info; // All ops info }; -struct TopCellInfo { - ResourcePtr resource; - FuncGraphPtr df_builder; - FuncGraphPtr bg; // Backward graph - std::string cell_id; - bool is_dynamic_cell{false}; +class TopCellInfo { + public: TopCellInfo() = default; - TopCellInfo(ResourcePtr r, FuncGraphPtr df, FuncGraphPtr backward_graph, std::string cellid) - : resource(std::move(r)), df_builder(std::move(df)), bg(std::move(backward_graph)), cell_id(std::move(cellid)) {} + TopCellInfo(bool topest, ResourcePtr r, FuncGraphPtr df, std::string cellid) + : is_topest(topest), resource(std::move(r)), df_builder(std::move(df)), cell_id(std::move(cellid)) {} + + bool is_topest{false}; + bool do_vm_compiled{false}; + ResourcePtr resource{nullptr}; + FuncGraphPtr df_builder{nullptr}; + FuncGraphPtr bg{nullptr}; // Backward graph + std::string cell_id; + std::string sens_id; + std::string weights_id; }; +using GraphInfoPtr = std::shared_ptr; +using CellInfoPtr = std::shared_ptr; +using TopCellInfoPtr = std::shared_ptr; + class PynativeExecutor : public std::enable_shared_from_this { public: static std::shared_ptr GetInstance() { @@ -119,11 +134,12 @@ class PynativeExecutor : public std::enable_shared_from_this { void NewGraph(const py::object &cell, const py::args &args); py::object Run(const py::object &cell, const py::tuple &args, const py::object &phase); py::object CheckGraph(const py::object &cell, const py::args &args); + py::object CheckAlreadyRun(const py::object &cell, const py::args &args); void EndGraph(const py::object &cell, const py::object &out, const py::args &args); void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args); // Get info - bool GetIsDynamicCell() const { return dynamic_cell_; } + bool GetIsDynamicCell() { return CheckRealDynamicCell(top_cell_id_); } // Call by python 
void Clear(const std::string &flag = ""); void Clean(); @@ -149,7 +165,7 @@ class PynativeExecutor : public std::enable_shared_from_this { template void VectorClear(T *vec, const std::string &cell_id) { for (auto it = vec->begin(); it != vec->end();) { - if (it->cell_id.find(cell_id) != std::string::npos) { + if ((*it)->cell_id.find(cell_id) != std::string::npos) { it = vec->erase(it); } else { it++; @@ -201,29 +217,39 @@ class PynativeExecutor : public std::enable_shared_from_this { void UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_exec_info, const py::object &out_real); void SaveTensorsInValueNode(const ResourcePtr &resource); void SaveAllValueNodeTensors(const FuncGraphPtr &graph); - void CleanPreMemoryInValueNode(const std::string &cell_id); + void CleanPreMemoryInValueNode(); // Construct grad graph void PushCurrentGraphToStack(); void PopGraphStack(); + void PushCurrentCellOpInfoToStack(); + void PopCurrentCellOpInfoFromStack(); FuncGraphPtr GetDfbuilder(const std::string &cell_id = ""); ResourcePtr GetResource(const std::string &cell_id = ""); void AddNestedGradOrder() { ++grad_order_; } void SubNestedGradOrder(); - bool IsNotNestedGrad() const; + bool IsNestedGrad() const; bool IsTopGraph(const std::string &cell_id); + bool IsTopestGraph(const std::string &cell_id); bool IsBpropGraph(const std::string &cell_id); + bool IsFirstGradStep(const std::string &cell_id); bool grad_running() const { return grad_is_running_; } void set_grad_runing(bool grad_runing) { grad_is_running_ = grad_runing; } void set_need_replace_forward(bool need_replace_forward) { need_replace_forward_ = need_replace_forward; } bool need_construct_graph() { return !graph_stack_.empty() && grad_flag_; } bool CheckCellGraph(const std::string &cell_id, bool is_grad = false); + bool CheckDynamicCell(const std::string &cell_id); + bool CheckRealDynamicCell(const std::string &cell_id); void UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, bool need_cloned = false, bool is_grad = false); + void ClearCnodeRes(const AnfNodePtr &node); + void UpdateCellDynamic(const std::string &cell_id); + bool CheckCellChanged(const std::string &cell_id); + void UpdateTopCellCompileInfo(const std::string &cell_id, bool vm_compiled); void ClearResidualRes(const std::string &cell_id); void DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph); void NewGraphInner(const py::object &cell, const py::args &args); - void MakeNewTopGraph(const string &cell_id, const py::args &args, const FuncGraphPtr &g); + void MakeNewTopGraph(const string &cell_id, const py::args &args); void EndGraphInner(const py::object &cell, const py::object &out, const py::args &args); void EndGraphByOutId(const py::object &cell, const std::string &cell_id, const py::object &out, const std::string &out_id, const py::args &args); @@ -232,38 +258,44 @@ class PynativeExecutor : public std::enable_shared_from_this { const std::string &cell_id, const py::args &args); std::string GetGradCellId(bool has_sens, const py::object &cell, const py::args &args, py::object *forward_args, py::object *sens = nullptr); + void ClearDynamicTopRes(const std::string &cell_id); void GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args); std::string GetCellId(const py::object &obj, const py::args &args); - std::pair CheckCellChanged(const std::string &cell_id, const py::object &weights, const py::object &sens); + std::string GetTensorCellId(const std::string &cell_id); + 
bool CheckGradParamsChanged(const std::string &cell_id, const py::object &weights, const py::object &sens); void SetGradGraphParams(const FuncGraphPtr &df_builder, const ResourcePtr &resource, size_t size); void GradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op, const std::vector &weights, size_t arg_size, const std::string &cell_id); std::vector GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder); abstract::AbstractBasePtrList GetArgsSpec(const py::args &args, const FuncGraphPtr &df_builder); - void UpdateGraphInfoMap(const std::string &cell_id); + void ClearUselessRes(const FuncGraphPtr &df_builder, const py::object &cell, const std::string &cell_id); + void ReplaceGraphParams(const FuncGraphPtr &df_builder, const FuncGraphPtr &forward_graph, + const std::string &cell_id); void SetNestedTopGraph(const py::object &cell, const py::args &args, const std::string &cell_id); void MakeNestedCnode(const std::string &cell_id, const py::args &args, const ResourcePtr &resource, const py::object &out, bool has_sens); - void SetNestedWeightsParam(const FuncGraphPtr &newfg, const std::string &cell_id, std::vector *inputs); + void RecoverGraphParams(const FuncGraphPtr &newfg, const std::string &cell_id, std::vector *inputs); bool MakeBpropNestedCnode(const py::object &cell, const py::object &out, const std::string &cell_id); // Hold graph(forward and grad) info + std::string GetCellOpInfo(); + void ReplaceCellOpInfoByCellId(const std::string &cell_id); void SetPyObjInGraphInfoMap(const FuncGraphPtr &g, const std::string &obj) { - graph_info_map_[g].objects.push_back(obj); + graph_info_map_[g]->objects.push_back(obj); } void SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node, bool is_param = false); void SetParamNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const ParameterPtr ¶m) { - graph_info_map_[g].params[id] = param; + graph_info_map_[g]->params[id] = param; } void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node, int64_t index = -1) { - graph_info_map_[g].node_map[id] = std::make_pair(node, std::vector{index}); + graph_info_map_[g]->node_map[id] = std::make_pair(node, std::vector{index}); } void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node, const std::vector &index) { - graph_info_map_[g].node_map[id] = std::make_pair(node, index); + graph_info_map_[g]->node_map[id] = std::make_pair(node, index); } void SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &id, const AnfNodePtr &node, const std::vector &index_sequence, bool is_param = false); @@ -274,7 +306,7 @@ class PynativeExecutor : public std::enable_shared_from_this { size_t grad_order_{0}; std::string top_cell_id_; bool grad_flag_{false}; - bool dynamic_cell_{false}; + bool has_dynamic_cell_{false}; bool grad_is_running_{false}; bool need_replace_forward_{true}; // The pointer of top python Cell object, which is always the network(inherit class Cell) ran in python test script, @@ -288,16 +320,15 @@ class PynativeExecutor : public std::enable_shared_from_this { FuncGraphPtr curr_g_{nullptr}; // Records forwrad graph, the bottom is top graph std::stack graph_stack_; + // Records op info of every cell, the bottom is op info of top cell + std::stack cell_op_info_stack_; // Use vector for keep order - std::vector cell_graph_list_; - std::vector top_cell_list_; + std::vector cell_graph_list_; + std::vector top_cell_list_; 
std::unordered_set cell_input_args_; - std::unordered_map cell_dynamic_map_; // Record all info for all cells - std::unordered_map graph_info_map_; - // key: cell_id, value: (send_id, weighs_id), cache for sens and weight change - std::unordered_map> cell_sw_map_; + OrderedMap graph_info_map_; std::unordered_map>> replace_weights_map_; // Used for runop and replace forward result of grad graph diff --git a/mindspore/common/api.py b/mindspore/common/api.py index 4b0a94b0fec..4af7e483cd0 100644 --- a/mindspore/common/api.py +++ b/mindspore/common/api.py @@ -314,6 +314,9 @@ class _PynativeExecutor: def check_graph(self, obj, *args, **kwargs): return self._executor.check_graph(obj, *args, *(kwargs.values())) + def check_run(self, obj, *args, **kwargs): + return self._executor.check_run(obj, *args, *(kwargs.values())) + def grad(self, grad, obj, weights, *args, **kwargs): self._executor.grad_net(grad, obj, weights, *args, *(kwargs.values())) diff --git a/mindspore/core/utils/ordered_map.h b/mindspore/core/utils/ordered_map.h index 1152990e461..d33817ed237 100644 --- a/mindspore/core/utils/ordered_map.h +++ b/mindspore/core/utils/ordered_map.h @@ -162,6 +162,14 @@ class OrderedMap { return pos == map_data_.end() ? sequential_data_.end() : (pos->second); } + ValueT at(const key_t &key) { + auto pos = map_data_.find(key); + if (pos == map_data_.end()) { + MS_LOG(EXCEPTION) << "Have no key " << key; + } + return pos->second->second; + } + // Remove the last element from the sequential_data_. void pop_back() { typename map_type::iterator pos = map_data_.find(sequential_data_.back().first); @@ -192,6 +200,24 @@ class OrderedMap { return 1; } + void update(const key_t &old_key, const key_t &new_key) { + auto old_it = find(old_key); + if (old_it == end()) { + return; + } + auto new_it = find(new_key); + if (new_it == end()) { + old_it->first = new_key; + auto nh = map_data_.extract(old_key); + nh.key() = new_key; + map_data_.insert(std::move(nh)); + return; + } + *old_it = *new_it; + (void)erase(old_key); + (void)erase(new_key); + } + private: map_type map_data_; sequential_type sequential_data_; diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index da7c5137eec..301126504c0 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -68,7 +68,7 @@ class Cell(Cell_): """ IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_construct_inputs_names', '_construct_inputs_num', '_create_time', '_mindspore_flags', '_parallel_inputs_run', - '_parameter_layout_dict', '_already_run', '_params_list', '_tensor_list', '_phase', + '_parameter_layout_dict', '_params_list', '_tensor_list', '_phase', '_auto_parallel_mode', '_backward_hook', '_bprop_debug', '_is_run', '_param_prefix', '_attr_synced', 'enable_hook', 'pynative', 'requires_grad', '_auto_parallel_compile_and_run', 'cell_type'] @@ -105,15 +105,10 @@ class Cell(Cell_): self._backward_hook = None self.enable_hook = False self._bprop_debug = False - self._already_run = False self.cell_type = None self._auto_parallel_compile_and_run = False self._support_non_tensor_inputs = False - @property - def already_run(self): - return self._already_run - def __getstate__(self): base = Cell_.__getstate__(self) return base, self.__dict__ @@ -150,10 +145,6 @@ class Cell(Cell_): # `` to `xxxxxxx` return str(self.__class__)[8:-2] - @already_run.setter - def already_run(self, value): - self._already_run = value - @property def create_time(self): return self._create_time @@ -334,12 +325,10 @@ class Cell(Cell_): for item in inputs: if 
isinstance(item, numpy.ndarray): raise TypeError("cell inputs should not be numpy array.") - origin_grad = [] if self.requires_grad is True: _pynative_exec.set_grad_flag(True) _pynative_exec.new_graph(self, *inputs, **kwargs) for cell in self.cells(): - origin_grad.append(cell.requires_grad) cell.set_grad(True) else: _pynative_exec.set_grad_flag(False) @@ -363,9 +352,6 @@ class Cell(Cell_): output = output.data if self.requires_grad is True: _pynative_exec.end_graph(self, output, *inputs, **kwargs) - for i, cell in enumerate(self.cells()): - cell.set_grad(origin_grad[i]) - self._already_run = True return output def _add_attr(self, name, value): diff --git a/mindspore/ops/composite/base.py b/mindspore/ops/composite/base.py index 14a469dab58..3ceb69fac05 100644 --- a/mindspore/ops/composite/base.py +++ b/mindspore/ops/composite/base.py @@ -319,36 +319,30 @@ class GradOperation(GradOperation_): GradOperation_.__init__(self, 'grad', get_all, get_by_list, sens_param) self.grad_fn = None self.fn = None - self.need_forward = False def _pynative_forward_run(self, args, kwargs, fn): """ Pynative forward run to build grad graph. """ - new_kwargs = {} + new_kwargs = kwargs if self.sens_param: if not 'sens' in kwargs.keys(): args = args[:-1] - new_kwargs = kwargs else: - for key, value in kwargs.items(): - if key != 'sens': - new_kwargs[key] = value + new_kwargs = kwargs.copy() + new_kwargs.pop('sens') for arg in args: if not isinstance(arg, Tensor): raise TypeError("grad inputs should be tensor in pynative mode") if isinstance(fn, FunctionType): - _pynative_exec.set_grad_flag(True) - _pynative_exec.new_graph(fn, *args, **new_kwargs) - output = fn(*args, **new_kwargs) - _pynative_exec.end_graph(fn, output, *args, **new_kwargs) + if not _pynative_exec.check_run(fn, *args, **new_kwargs): + _pynative_exec.set_grad_flag(True) + _pynative_exec.new_graph(fn, *args, **new_kwargs) + output = fn(*args, **new_kwargs) + _pynative_exec.end_graph(fn, output, *args, **new_kwargs) else: - if fn.already_run and not fn.requires_grad: - raise ValueError("obj must set_grad.") - if not fn.already_run: - self.need_forward = True - if self.need_forward: + # Check if fn have run already + if not _pynative_exec.check_run(fn, *args, **new_kwargs): fn.set_grad() fn(*args, **new_kwargs) - fn.already_run = False def __call__(self, fn, weights=None): grad_ = GradOperation(self.get_all, self.get_by_list, self.sens_param) @@ -367,7 +361,6 @@ class GradOperation(GradOperation_): def after_grad(*args, **kwargs): if _pynative_exec.check_graph(fn, *args, **kwargs): print("Another grad step is running") - fn.already_run = False self._pynative_forward_run(args, kwargs, fn) _pynative_exec.grad(grad_, fn, weights, *args, **kwargs) out = _pynative_exec(fn, *args, **kwargs)
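Taken together, the patch keys every compiled graph by a cell id built from the cell object's id plus the shape and type of each argument (GetCellId), and GetTensorCellId falls back to a previously recorded Tensor-bearing id when a fresh id differs only in "NoShape" segments. A small self-contained sketch of that underscore-delimited matching, using made-up ids rather than real MindSpore ones:

#include <iostream>
#include <string>
#include <vector>

// Split an id of the form "ptr_shapeType_shapeType..." on '_', as the
// lambda inside GetTensorCellId does.
std::vector<std::string> Split(const std::string &str) {
  std::vector<std::string> value;
  size_t pos = 0;
  size_t pre_pos = 0;
  while ((pos = str.find_first_of('_', pre_pos)) != std::string::npos) {
    value.emplace_back(str.substr(pre_pos, pos - pre_pos + 1));
    pre_pos = pos + 1;
  }
  value.emplace_back(str.substr(pre_pos));
  return value;
}

// A cached id matches a fresh one when every segment is equal or the fresh
// segment carries no shape information ("NoShape").
bool Matches(const std::string &cached, const std::string &fresh) {
  auto pre = Split(cached);
  auto cur = Split(fresh);
  if (pre.size() != cur.size()) {
    return false;
  }
  for (size_t i = 0; i < pre.size(); ++i) {
    if (cur[i].find("NoShape") == std::string::npos && cur[i] != pre[i]) {
      return false;
    }
  }
  return true;
}

int main() {
  std::cout << Matches("0xabc_(2,3)Tensor", "0xabc_NoShape") << std::endl;      // 1
  std::cout << Matches("0xabc_(2,3)Tensor", "0xabc_(4,4)Tensor") << std::endl;  // 0
  return 0;
}

The effect is that an argument whose abstract carries no shape can still reuse the graph compiled for a Tensor-typed signature instead of forcing a new top cell.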
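CellInfo and TopCellInfo also change from value types to shared_ptr-held classes (CellInfoPtr, TopCellInfoPtr), so cell_graph_list_, top_cell_list_, and graph_info_map_ can refer to one record and VectorClear can erase entries without copying graphs. A sketch of the erase pattern under a simplified record type:

#include <cassert>
#include <memory>
#include <string>
#include <vector>

struct TopCell {  // simplified stand-in for TopCellInfo
  std::string cell_id;
};
using TopCellPtr = std::shared_ptr<TopCell>;

// Mirrors PynativeExecutor::VectorClear: erase every record whose id
// contains the given cell id.
void VectorClear(std::vector<TopCellPtr> *vec, const std::string &cell_id) {
  for (auto it = vec->begin(); it != vec->end();) {
    if ((*it)->cell_id.find(cell_id) != std::string::npos) {
      it = vec->erase(it);
    } else {
      ++it;
    }
  }
}

int main() {
  std::vector<TopCellPtr> cells{std::make_shared<TopCell>(TopCell{"net_a"}),
                                std::make_shared<TopCell>(TopCell{"net_b"})};
  auto alias = cells[0];  // another container can keep the record alive
  VectorClear(&cells, "net_a");
  assert(cells.size() == 1 && alias->cell_id == "net_a");
  return 0;
}

Because the records are shared pointers, erasing from one container no longer invalidates the record while another list still holds it.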
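CleanPreMemoryInValueNode now keys off the stored top_cell_id_ instead of taking a cell id, and for dynamic cells it walks cell_op_index_with_tensor_id_[top_cell_id_] to release device memory held by the previous step's forward outputs. Roughly, with stand-in types since the real code goes through DeviceAddress, the shape of that cleanup is:

#include <map>
#include <set>
#include <string>
#include <vector>

struct FakeTensor {
  std::string id;
  int *device_address{nullptr};  // stand-in for a DeviceAddress pointer
};

// Collect the tensor ids recorded for the previous step's forward ops, then
// drop the device buffers of the cached tensors that carried them.
void CleanPreMemory(const std::map<std::string, std::vector<std::string>> &op_index_with_tensor_id,
                    std::vector<FakeTensor> *cached_tensors) {
  std::set<std::string> forward_op_tensor_id;
  for (const auto &elem : op_index_with_tensor_id) {
    forward_op_tensor_id.insert(elem.second.begin(), elem.second.end());
  }
  for (auto &tensor : *cached_tensors) {
    if (tensor.device_address != nullptr && forward_op_tensor_id.count(tensor.id) != 0) {
      tensor.device_address = nullptr;  // tensor->set_device_address(nullptr)
    }
  }
}

int main() {
  std::map<std::string, std::vector<std::string>> prev{{"MatMul0", {"t1"}}};
  int buf = 0;
  std::vector<FakeTensor> cached{{"t1", &buf}, {"t2", &buf}};
  CleanPreMemory(prev, &cached);
  return cached[0].device_address == nullptr ? 0 : 1;
}

On CPU the function returns early, since host memory is not recycled this way.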
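CheckCellChanged, which previously compared both the sens and weights ids per cell, is split: CheckGradParamsChanged stores the ids on the TopCellInfo itself, and only a changed weights id forces recompilation, while a changed sens id is merely recorded. Restated compactly with a simplified TopCell stand-in:

#include <string>

struct TopCell {
  std::string cell_id;
  std::string sens_id;
  std::string weights_id;
};

// A recompile is needed only when the weights identity changes; a changed
// sens id is written back but does not by itself trigger recompilation.
bool GradParamsChanged(TopCell *top, const std::string &sens_id, const std::string &weights_id) {
  if (!top->sens_id.empty() && top->sens_id != sens_id) {
    top->sens_id = sens_id;
  }
  if (!top->weights_id.empty() && top->weights_id != weights_id) {
    top->weights_id = weights_id;
    return true;
  }
  return false;
}

int main() {
  TopCell top{"cell", "sens_v1", "weights_v1"};
  bool changed = GradParamsChanged(&top, "sens_v2", "weights_v1");
  return changed ? 1 : 0;  // a sens change alone does not trigger a recompile
}

GradNetInner uses the result to choose between the "Gradgraph already compiled" fast path and a rebuild.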
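Finally, ordered_map.h gains a checked at() that raises on a missing key instead of silently yielding a default or an end iterator, and the executor now uses it for graph_info_map_ lookups so a stale FuncGraph key fails loudly. The contract, restated on std::map for illustration with a hypothetical At helper:

#include <map>
#include <stdexcept>
#include <string>

// Checked lookup: throw on a miss, matching the intent of OrderedMap::at
// (which raises via MS_LOG(EXCEPTION) when the key is absent).
template <typename K, typename V>
const V &At(const std::map<K, V> &m, const K &key) {
  auto pos = m.find(key);
  if (pos == m.end()) {
    throw std::out_of_range("Have no key");
  }
  return pos->second;
}

int main() {
  std::map<std::string, int> m{{"g1", 7}};
  return At(m, std::string("g1")) == 7 ? 0 : 1;
}

std::map::at offers the same guarantee in the standard library; the patch adds the equivalent to OrderedMap, where a plain find() would silently return end().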