forked from mindspore-Ecosystem/mindspore
!10330 Fix pynative parameters second derivative
From: @zjun3021 Reviewed-by: @kisnwang,@chujinjin Signed-off-by: @chujinjin
This commit is contained in:
commit
695cdbbe69
|
@ -1439,6 +1439,12 @@ std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args &
|
|||
return cell_id;
|
||||
}
|
||||
|
||||
void PynativeExecutor::DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph) {
|
||||
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
||||
DumpIR(filename, graph);
|
||||
}
|
||||
}
|
||||
|
||||
bool PynativeExecutor::IsNotNestedGrad() const {
|
||||
MS_LOG(DEBUG) << "Grad nested count is " << grad_order_;
|
||||
return grad_order_ <= 1;
|
||||
|
@ -1851,6 +1857,7 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string
|
|||
curr_g_->set_output(output_node);
|
||||
MS_LOG(DEBUG) << "Current graph " << curr_g_->output()->DebugString();
|
||||
if (EndBpropGraph(cell_id)) {
|
||||
MS_LOG(DEBUG) << "Get bprop function cell";
|
||||
return;
|
||||
}
|
||||
auto resource = GetResource(cell_id);
|
||||
|
@ -1875,13 +1882,9 @@ void PynativeExecutor::EndGraphByOutId(const py::object &cell, const std::string
|
|||
SetTupleArgsToGraphInfoMap(graph_prev, out, out_cnode);
|
||||
SetNodeMapInGraphInfoMap(graph_prev, GetId(out), out_cnode);
|
||||
} else {
|
||||
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
||||
DumpIR("before_resolve.ir", newfg);
|
||||
}
|
||||
DumpGraphIR("before_resolve.ir", newfg);
|
||||
parse::ResolveFuncGraph(newfg, resource);
|
||||
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
||||
DumpIR("after_resolve.ir", newfg);
|
||||
}
|
||||
DumpGraphIR("after_resolve.ir", newfg);
|
||||
resource->set_func_graph(newfg);
|
||||
PopGraphStack();
|
||||
}
|
||||
|
@ -1907,10 +1910,12 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt
|
|||
if (it != cell_graph_list_.end()) {
|
||||
it->is_grad = is_grad;
|
||||
it->fg = g;
|
||||
MS_LOG(DEBUG) << "Update bprop bg";
|
||||
MS_LOG(DEBUG) << "Update bprop bg cell id " << cell_id;
|
||||
} else {
|
||||
py::function bprop_func = py::getattr(cell, parse::CUSTOM_BPROP_NAME);
|
||||
auto cell_info = CellInfo(false, true, g, cell_id, GetId(bprop_func));
|
||||
auto bprop_func_cell_id = GetId(bprop_func);
|
||||
MS_LOG(DEBUG) << "Add new bprop cell_id " << cell_id << " bprop func cell id " << bprop_func_cell_id;
|
||||
auto cell_info = CellInfo(false, true, g, cell_id, bprop_func_cell_id);
|
||||
cell_graph_list_.insert(cell_graph_list_.begin(), cell_info);
|
||||
}
|
||||
return;
|
||||
|
@ -1959,13 +1964,11 @@ FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const FuncG
|
|||
(void)bprop_graph->transforms().emplace(std::make_pair("primal", FuncGraphTransform(g)));
|
||||
}
|
||||
}
|
||||
// Obtain grad graph
|
||||
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
||||
DumpIR("fg.ir", g);
|
||||
}
|
||||
DumpGraphIR("fg.ir", g);
|
||||
auto is_top = IsTopGraph(cell_id);
|
||||
MS_LOG(DEBUG) << "Grad top cell " << is_top;
|
||||
set_need_replace_forward(IsNotNestedGrad());
|
||||
// Obtain grad graph
|
||||
auto newfg = ad::Grad(g, r, is_top);
|
||||
|
||||
if (is_custom_bprop) {
|
||||
|
@ -2039,11 +2042,9 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje
|
|||
auto args_spec = GetArgsSpec(args, df_builder);
|
||||
resource->set_args_spec(args_spec);
|
||||
// Get real grad graph
|
||||
DumpGraphIR("before_grad.ir", resource->func_graph());
|
||||
GradGraph(resource->func_graph(), grad, w_args, size, cell_id);
|
||||
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
||||
DumpIR("before_grad.ir", resource->func_graph());
|
||||
DumpIR("after_grad.ir", df_builder);
|
||||
}
|
||||
DumpGraphIR("after_grad.ir", df_builder);
|
||||
resource->set_func_graph(df_builder);
|
||||
resource->manager()->KeepRoots({df_builder});
|
||||
resource->results()[pipeline::kBackend] = compile::CreateBackend();
|
||||
|
@ -2127,30 +2128,35 @@ void PynativeExecutor::SetNestedTopGraph(const py::object &cell, const py::args
|
|||
}
|
||||
MS_EXCEPTION_IF_NULL(forward_graph);
|
||||
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
|
||||
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
||||
DumpIR("nested_bprop.ir", forward_graph);
|
||||
}
|
||||
DumpGraphIR("nested_bprop.ir", forward_graph);
|
||||
// Custom bprop get backward graph(before opt), which use like other forward graph
|
||||
curr_g_ = forward_graph;
|
||||
resource->set_func_graph(forward_graph);
|
||||
return;
|
||||
}
|
||||
|
||||
// Copy weights
|
||||
std::vector<AnfNodePtr> weights_params{};
|
||||
for (const auto &it : graph_info_map_.at(forward_graph).params) {
|
||||
if (it.second->has_default()) {
|
||||
weights_params.emplace_back(it.second);
|
||||
graph_info_map_.at(df_builder).params.emplace(it.first, it.second);
|
||||
SetNodeMapInGraphInfoMap(df_builder, it.first, it.second);
|
||||
}
|
||||
}
|
||||
MS_LOG(DEBUG) << "Get weights params size " << weights_params.size();
|
||||
df_builder->set_parameters(weights_params);
|
||||
// Copy weights parameters
|
||||
resource->manager()->AddFuncGraph(forward_graph);
|
||||
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
||||
DumpIR("nested_fg.ir", forward_graph);
|
||||
auto manager = Manage({forward_graph}, false);
|
||||
for (const auto &it : graph_info_map_.at(forward_graph).params) {
|
||||
if (!it.second->has_default()) {
|
||||
continue;
|
||||
}
|
||||
auto new_param = df_builder->add_parameter();
|
||||
new_param->set_abstract(it.second->abstract());
|
||||
new_param->set_name(it.second->name());
|
||||
new_param->set_default_param(it.second->default_param());
|
||||
ScopePtr scope = (it.second->scope() != kDefaultScope) ? it.second->scope() : kDefaultScope;
|
||||
new_param->set_scope(scope);
|
||||
manager->Replace(it.second, new_param);
|
||||
replace_weights_map_[forward_graph].emplace_back(std::make_pair(it.second, new_param));
|
||||
MS_LOG(DEBUG) << "Old param ptr " << it.second.get() << " name " << it.second->name();
|
||||
|
||||
graph_info_map_.at(df_builder).params[it.first] = new_param;
|
||||
SetParamNodeMapInGraphInfoMap(df_builder, it.first, new_param);
|
||||
SetNodeMapInGraphInfoMap(df_builder, it.first, new_param);
|
||||
}
|
||||
DumpGraphIR("nested_fg.ir", forward_graph);
|
||||
set_need_replace_forward(false);
|
||||
auto newfg = MakeGradGraph(cell, forward_graph, resource, cell_id, args);
|
||||
resource->set_func_graph(newfg);
|
||||
|
@ -2396,15 +2402,18 @@ void PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::arg
|
|||
MS_LOG(DEBUG) << "Get pre graph ptr " << graph_prev.get();
|
||||
auto newfg = resource->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(newfg);
|
||||
auto size = args.size();
|
||||
auto inputs_size = args.size();
|
||||
if (has_sens) {
|
||||
size -= 1;
|
||||
inputs_size -= 1;
|
||||
}
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
inputs.emplace_back(NewValueNode(newfg));
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
for (size_t i = 0; i < inputs_size; ++i) {
|
||||
inputs.emplace_back(GetInput(args[i], false));
|
||||
}
|
||||
if (newfg->parameters().size() > inputs_size) {
|
||||
SetNestedWeigthsParam(newfg, cell_id, &inputs);
|
||||
}
|
||||
auto out_id = GetId(out);
|
||||
auto cnode = graph_prev->NewCNode(inputs);
|
||||
SetTupleArgsToGraphInfoMap(graph_prev, out, cnode);
|
||||
|
@ -2412,6 +2421,38 @@ void PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::arg
|
|||
MS_LOG(DEBUG) << "Nested make cnode is " << cnode->DebugString(4);
|
||||
}
|
||||
|
||||
void PynativeExecutor::SetNestedWeigthsParam(const FuncGraphPtr &newfg, const std::string &cell_id,
|
||||
std::vector<AnfNodePtr> *inputs) {
|
||||
FuncGraphPtr forward_graph = nullptr;
|
||||
auto ic = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(),
|
||||
[&cell_id](const CellInfo &value) { return value.cell_id == cell_id; });
|
||||
if (ic != cell_graph_list_.end()) {
|
||||
forward_graph = ic->fg;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(forward_graph);
|
||||
auto params = newfg->parameters();
|
||||
auto manage = Manage({newfg}, false);
|
||||
for (const auto &it : params) {
|
||||
auto param = it->cast<ParameterPtr>();
|
||||
if (!param->has_default()) {
|
||||
continue;
|
||||
}
|
||||
auto ir = replace_weights_map_.find(forward_graph);
|
||||
if (ir == replace_weights_map_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Not find forward_graph in repalce weigths map";
|
||||
}
|
||||
for (const auto &ip : ir->second) {
|
||||
MS_LOG(DEBUG) << "Get param name " << param->name() << " cache name " << ip.second->name();
|
||||
if (ip.second->name() == param->name()) {
|
||||
manage->Replace(param, ip.first);
|
||||
inputs->emplace_back(ip.first);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
replace_weights_map_.erase(forward_graph);
|
||||
}
|
||||
|
||||
void PynativeExecutor::Clear(const std::string &cell_id) {
|
||||
if (cell_id.empty()) {
|
||||
Clean();
|
||||
|
@ -2461,6 +2502,7 @@ void PynativeExecutor::ClearRes() {
|
|||
|
||||
graph_info_map_.clear();
|
||||
cell_sw_map_.clear();
|
||||
replace_weights_map_.clear();
|
||||
cell_graph_list_.clear();
|
||||
top_cell_list_.clear();
|
||||
op_index_map_.clear();
|
||||
|
|
|
@ -60,7 +60,7 @@ void ClearPyNativeSession();
|
|||
struct GraphInfo {
|
||||
std::string cell_id;
|
||||
AnfNodePtr output;
|
||||
std::unordered_map<std::string, ParameterPtr> params; // hold input parameters and cell weigths
|
||||
OrderedMap<std::string, ParameterPtr> params; // hold input parameters and cell weigths
|
||||
std::unordered_map<std::string, std::pair<AnfNodePtr, std::vector<int64_t>>> node_map;
|
||||
std::vector<std::string> objects;
|
||||
GraphInfo() = default;
|
||||
|
@ -210,6 +210,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
|||
void UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
|
||||
bool need_cloned = false, bool is_grad = false);
|
||||
void ClearResidualRes(const std::string &cell_id);
|
||||
void DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph);
|
||||
void NewGraphInner(const py::object &cell, const py::args &args);
|
||||
void MakeNewTopGraph(const string &cell_id, const py::args &args, const FuncGraphPtr &g);
|
||||
void EndGraphInner(const py::object &cell, const py::object &out, const py::args &args);
|
||||
|
@ -233,6 +234,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
|||
void SetNestedTopGraph(const py::object &cell, const py::args &args, const std::string &cell_id);
|
||||
void MakeNestedCnode(const std::string &cell_id, const py::args &args, const ResourcePtr &resource,
|
||||
const py::object &out, bool has_sens);
|
||||
void SetNestedWeigthsParam(const FuncGraphPtr &newfg, const std::string &cell_id, std::vector<AnfNodePtr> *inputs);
|
||||
bool MakeBpropNestedCnode(const py::object &cell, const py::object &out, const std::string &cell_id);
|
||||
|
||||
// Hold graph(forward and grad) info
|
||||
|
@ -242,7 +244,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
|||
void SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node,
|
||||
bool is_param = false);
|
||||
// Records `param` under `id` in the parameter table of graph `g`, overwriting any parameter already stored for
// that id. The GraphInfo entry for `g` is created in graph_info_map_ if it does not exist yet.
void SetParamNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const ParameterPtr &param) {
  graph_info_map_[g].params[id] = param;
}
|
||||
void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node,
|
||||
int64_t index = -1) {
|
||||
|
@ -269,15 +271,16 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
|||
// Records forward graphs; the bottom of the stack is the top graph
|
||||
std::stack<FuncGraphPtr> graph_stack_;
|
||||
|
||||
// Use a vector to keep order
|
||||
std::vector<CellInfo> cell_graph_list_;
|
||||
std::vector<TopCellInfo> top_cell_list_;
|
||||
std::unordered_set<std::string> cell_input_args_;
|
||||
std::unordered_map<std::string, bool> cell_dynamic_map_;
|
||||
// Record all info for all cells
|
||||
std::unordered_map<FuncGraphPtr, GraphInfo> graph_info_map_;
|
||||
// Use a vector to keep order
|
||||
std::vector<CellInfo> cell_graph_list_;
|
||||
std::vector<TopCellInfo> top_cell_list_;
|
||||
// key: cell_id, value: (sens_id, weights_id), cache for sens and weights change
|
||||
std::unordered_map<std::string, std::pair<std::string, std::string>> cell_sw_map_;
|
||||
std::unordered_map<FuncGraphPtr, std::vector<std::pair<ParameterPtr, ParameterPtr>>> replace_weights_map_;
|
||||
|
||||
// Used for runop and replace forward result of grad graph
|
||||
std::unordered_map<std::string, size_t> op_index_map_;
|
||||
|
|
Loading…
Reference in New Issue