!19233 Fix parameter handling for high-order grad (highgrad)

Merge pull request !19233 from zjun/fix_parameter
This commit is contained in:
i-robot 2021-07-05 01:28:23 +00:00 committed by Gitee
commit 9fdd23f783
1 changed file with 33 additions and 20 deletions

View File

@ -1933,7 +1933,7 @@ void GradExecutor::InitResourceAndDfBuilder(const std::string &cell_id, const py
}
bprop_grad_stack_.push(std::make_pair(cell_id, false));
} else if (top_cell()->grad_order() != grad_order_) {
MakeNewTopGraph(cell_id, args, true);
MakeNewTopGraph(cell_id, args, false);
bprop_grad_stack_.push(std::make_pair(cell_id, true));
}
};
@ -2108,11 +2108,7 @@ void GradExecutor::SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const p
void GradExecutor::CreateMakeTupleNodeForMultiOut(const std::string &cell_id, const FuncGraphPtr &curr_g,
const py::object &out, const std::string &out_id) {
MS_EXCEPTION_IF_NULL(curr_g);
if (!(py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out))) {
MS_LOG(EXCEPTION) << "The out of top cell should be tuple or list when set maketuple as output node";
}
auto out_tuple = out.cast<py::tuple>();
// get input node and value
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
ValuePtrList input_args;
@ -2139,9 +2135,14 @@ void GradExecutor::EndGraphInner(py::object *ret, const py::object &cell, const
MS_LOG(DEBUG) << "EndGraphInner start " << args.size() << " " << cell_id;
if (cell_stack_.empty()) {
MS_LOG(DEBUG) << "Current cell " << cell_id << " no need to run EndGraphInner again";
if (top_cell()->is_topest() && cell_id == top_cell()->cell_id()) {
PopHighOrderGraphStack();
set_grad_flag(false);
if (cell_id == top_cell()->cell_id()) {
if (top_cell()->is_topest()) {
set_grad_flag(false);
}
auto outer_top_cell = PopHighOrderGraphStack();
if (outer_top_cell != nullptr) {
set_top_cell(outer_top_cell);
}
}
return;
}
@ -2151,8 +2152,7 @@ void GradExecutor::EndGraphInner(py::object *ret, const py::object &cell, const
MS_EXCEPTION_IF_NULL(graph_info);
if (graph_info->node_map.find(out_id) == graph_info->node_map.end()) {
if (py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out)) {
auto tuple_out = py::cast<py::tuple>(out);
CreateMakeTupleNodeForMultiOut(cell_id, curr_g_, tuple_out, out_id);
CreateMakeTupleNodeForMultiOut(cell_id, curr_g_, out, out_id);
} else {
MS_LOG(DEBUG) << "Set ValueNode as output for graph, out id: " << out_id;
MakeValueNode(out, out_id);
@ -2205,6 +2205,7 @@ void GradExecutor::DoGradForCustomBprop(const py::object &cell, const py::object
if (custom_bprop_cell_count_ != 0) {
return;
}
MS_LOG(DEBUG) << "Do grad for custom bprop";
size_t par_number = py::tuple(parse::python_adapter::CallPyObjMethod(cell, "get_parameters")).size();
if (par_number > 0) {
MS_LOG(EXCEPTION) << "When user defines the net bprop, there are " << par_number
@ -2623,20 +2624,32 @@ void GradExecutor::MakeNestedCnode(const py::object &cell, const std::string &ce
}
ValuePtrList weights_args;
std::unordered_set<std::string> params_set;
for (const auto &sec : second_graph_info->params) {
if (sec.second->has_default()) {
params_set.emplace(sec.first);
}
}
auto manager = Manage({first_grad_fg}, false);
for (const auto &fir : first_graph_info->params) {
auto p = fir.second->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(p);
if (!p->has_default()) {
if (!fir.second->has_default()) {
continue;
}
for (const auto &sec : second_graph_info->params) {
MS_LOG(DEBUG) << "Param name " << fir.first << " ptr " << fir.second.get();
if (fir.second->name() == sec.second->name()) {
manager->Replace(fir.second, sec.second);
inputs.emplace_back(sec.second);
weights_args.emplace_back(sec.second->default_param());
break;
// No this weight param, need add to second graph
if (!params_set.count(fir.first)) {
SetParamNodeMapInGraphInfoMap(second_df_builder, fir.first, fir.second);
inputs.emplace_back(fir.second);
weights_args.emplace_back(fir.second->default_param());
} else {
// Need replace
for (const auto &sec : second_graph_info->params) {
MS_LOG(DEBUG) << "Param name " << fir.first << " ptr " << fir.second.get();
if (sec.second->has_default() && fir.second->name() == sec.second->name()) {
manager->Replace(fir.second, sec.second);
inputs.emplace_back(sec.second);
weights_args.emplace_back(sec.second->default_param());
break;
}
}
}
}