fix op id issue in pynative

This commit is contained in:
kingfo 2020-09-04 14:49:14 +08:00
parent 4499d126d6
commit cfda024336
3 changed files with 28 additions and 2 deletions

View File

@ -302,6 +302,7 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor
auto inst = pynative::PynativeExecutor::GetInstance();
inst->SaveOpForwardValue(input_value.second, input_value.first);
auto input_value_node = NewValueNode(input_value.first);
input_value_node->set_has_new_value(true);
manager->Replace(paras[i], input_value_node);
}
}

View File

@ -632,6 +632,9 @@ ValuePtr PynativeExecutor::GetForwardValue(const OpExecInfoPtr &op_exec_info) {
MS_LOG(DEBUG) << "Get: " << op_exec_info->op_name << "(" << op << "), " << iter->second;
return iter->second;
}
if (!first_grad_step_) {
++op_id_map_[id];
}
return nullptr;
}
@ -979,7 +982,10 @@ void ClearPyNativeSession() { session = nullptr; }
PynativeExecutor::~PynativeExecutor() { ClearRes(); }
PynativeExecutor::PynativeExecutor() { grad_flag_ = false; }
PynativeExecutor::PynativeExecutor() {
grad_flag_ = false;
first_grad_step_ = false;
}
void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
auto cell_id = GetCellId(cell, args);
@ -1000,6 +1006,8 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
cell_resource_map_[cell_id] = resource_;
df_builder_ = std::make_shared<FuncGraph>();
MS_LOG(DEBUG) << "First new graph" << top_g_.get();
first_grad_step_ = true;
top_graph_cells_.insert(cell_id);
Pushp();
} else {
Pushp();
@ -1181,7 +1189,9 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
MS_LOG(DEBUG) << "Current graph" << curr_g_->output()->DebugString();
resource_->manager()->AddFuncGraph(curr_g_);
// custom bprop debug
bool need_replace_param = false;
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
need_replace_param = true;
size_t par_number = py::tuple(parse::python_adapter::CallPyObjMethod(cell, "get_parameters")).size();
if (par_number > 0) {
MS_LOG(EXCEPTION) << "When user defines the net bprop, there are " << par_number
@ -1195,6 +1205,15 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
}
}
auto newfg = ad::Grad(curr_g_, resource_, curr_g_ == top_g_);
if (need_replace_param) {
auto params = newfg->parameters();
auto manager = Manage({newfg}, false);
for (size_t i = 0; i < params.size(); i++) {
ValuePtr value = PyAttrValue(args[i]);
auto v_node = NewValueNode(value);
manager->Replace(params[i], v_node);
}
}
graph_info_map_.erase(curr_g_);
if (curr_g_ != top_g_) {
Popp();
@ -1355,6 +1374,9 @@ void PynativeExecutor::Clear(const std::string &flag) {
ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
}
ConfigManager::GetInstance().ResetIterNum();
if (top_graph_cells_.find(flag) != top_graph_cells_.end()) {
op_forward_map_.clear();
}
return;
}
@ -1363,6 +1385,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
top_g_ = nullptr;
df_builder_ = nullptr;
curr_g_ = nullptr;
first_grad_step_ = false;
graph_info_map_.clear();
op_id_map_.clear();
obj_to_forward_id_.clear();
@ -1374,7 +1397,6 @@ void PynativeExecutor::Clean() {
MS_LOG(DEBUG) << "Clean all res";
Clear();
grad_flag_ = false;
op_forward_map_.clear();
ad::CleanRes();
pipeline::ReclaimOptimizer();
}

View File

@ -24,6 +24,7 @@
#include <unordered_map>
#include <mutex>
#include <stack>
#include <set>
#include "pybind11/pybind11.h"
#include "pybind11/numpy.h"
@ -145,6 +146,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
static ResourcePtr resource_;
static int graph_id_;
bool grad_flag_;
bool first_grad_step_;
std::unordered_map<std::string, FuncGraphPtr> graph_map_;
std::unordered_map<std::string, FuncGraphPtr> cell_graph_map_;
std::unordered_map<std::string, ResourcePtr> cell_resource_map_;
@ -158,6 +160,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
FuncGraphPtr df_builder_;
FuncGraphPtr curr_g_;
std::unordered_map<std::string, AbstractListMap> prim_abs_list_;
std::set<std::string> top_graph_cells_;
};
using PynativeExecutorPtr = std::shared_ptr<PynativeExecutor>;