forked from mindspore-Ecosystem/mindspore
fix op id issue in pynative
This commit is contained in:
parent
4499d126d6
commit
cfda024336
|
@ -302,6 +302,7 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor
|
|||
auto inst = pynative::PynativeExecutor::GetInstance();
|
||||
inst->SaveOpForwardValue(input_value.second, input_value.first);
|
||||
auto input_value_node = NewValueNode(input_value.first);
|
||||
input_value_node->set_has_new_value(true);
|
||||
manager->Replace(paras[i], input_value_node);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -632,6 +632,9 @@ ValuePtr PynativeExecutor::GetForwardValue(const OpExecInfoPtr &op_exec_info) {
|
|||
MS_LOG(DEBUG) << "Get: " << op_exec_info->op_name << "(" << op << "), " << iter->second;
|
||||
return iter->second;
|
||||
}
|
||||
if (!first_grad_step_) {
|
||||
++op_id_map_[id];
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -979,7 +982,10 @@ void ClearPyNativeSession() { session = nullptr; }
|
|||
|
||||
PynativeExecutor::~PynativeExecutor() { ClearRes(); }
|
||||
|
||||
PynativeExecutor::PynativeExecutor() { grad_flag_ = false; }
|
||||
PynativeExecutor::PynativeExecutor() {
|
||||
grad_flag_ = false;
|
||||
first_grad_step_ = false;
|
||||
}
|
||||
|
||||
void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
|
||||
auto cell_id = GetCellId(cell, args);
|
||||
|
@ -1000,6 +1006,8 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
|
|||
cell_resource_map_[cell_id] = resource_;
|
||||
df_builder_ = std::make_shared<FuncGraph>();
|
||||
MS_LOG(DEBUG) << "First new graph" << top_g_.get();
|
||||
first_grad_step_ = true;
|
||||
top_graph_cells_.insert(cell_id);
|
||||
Pushp();
|
||||
} else {
|
||||
Pushp();
|
||||
|
@ -1181,7 +1189,9 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
|
|||
MS_LOG(DEBUG) << "Current graph" << curr_g_->output()->DebugString();
|
||||
resource_->manager()->AddFuncGraph(curr_g_);
|
||||
// custom bprop debug
|
||||
bool need_replace_param = false;
|
||||
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
|
||||
need_replace_param = true;
|
||||
size_t par_number = py::tuple(parse::python_adapter::CallPyObjMethod(cell, "get_parameters")).size();
|
||||
if (par_number > 0) {
|
||||
MS_LOG(EXCEPTION) << "When user defines the net bprop, there are " << par_number
|
||||
|
@ -1195,6 +1205,15 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
|
|||
}
|
||||
}
|
||||
auto newfg = ad::Grad(curr_g_, resource_, curr_g_ == top_g_);
|
||||
if (need_replace_param) {
|
||||
auto params = newfg->parameters();
|
||||
auto manager = Manage({newfg}, false);
|
||||
for (size_t i = 0; i < params.size(); i++) {
|
||||
ValuePtr value = PyAttrValue(args[i]);
|
||||
auto v_node = NewValueNode(value);
|
||||
manager->Replace(params[i], v_node);
|
||||
}
|
||||
}
|
||||
graph_info_map_.erase(curr_g_);
|
||||
if (curr_g_ != top_g_) {
|
||||
Popp();
|
||||
|
@ -1355,6 +1374,9 @@ void PynativeExecutor::Clear(const std::string &flag) {
|
|||
ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
|
||||
}
|
||||
ConfigManager::GetInstance().ResetIterNum();
|
||||
if (top_graph_cells_.find(flag) != top_graph_cells_.end()) {
|
||||
op_forward_map_.clear();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -1363,6 +1385,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
|
|||
top_g_ = nullptr;
|
||||
df_builder_ = nullptr;
|
||||
curr_g_ = nullptr;
|
||||
first_grad_step_ = false;
|
||||
graph_info_map_.clear();
|
||||
op_id_map_.clear();
|
||||
obj_to_forward_id_.clear();
|
||||
|
@ -1374,7 +1397,6 @@ void PynativeExecutor::Clean() {
|
|||
MS_LOG(DEBUG) << "Clean all res";
|
||||
Clear();
|
||||
grad_flag_ = false;
|
||||
op_forward_map_.clear();
|
||||
ad::CleanRes();
|
||||
pipeline::ReclaimOptimizer();
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include <unordered_map>
|
||||
#include <mutex>
|
||||
#include <stack>
|
||||
#include <set>
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/numpy.h"
|
||||
|
@ -145,6 +146,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
|||
static ResourcePtr resource_;
|
||||
static int graph_id_;
|
||||
bool grad_flag_;
|
||||
bool first_grad_step_;
|
||||
std::unordered_map<std::string, FuncGraphPtr> graph_map_;
|
||||
std::unordered_map<std::string, FuncGraphPtr> cell_graph_map_;
|
||||
std::unordered_map<std::string, ResourcePtr> cell_resource_map_;
|
||||
|
@ -158,6 +160,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
|||
FuncGraphPtr df_builder_;
|
||||
FuncGraphPtr curr_g_;
|
||||
std::unordered_map<std::string, AbstractListMap> prim_abs_list_;
|
||||
std::set<std::string> top_graph_cells_;
|
||||
};
|
||||
|
||||
using PynativeExecutorPtr = std::shared_ptr<PynativeExecutor>;
|
||||
|
|
Loading…
Reference in New Issue