forked from mindspore-Ecosystem/mindspore
!19233 Fix parameter of highgrad
Merge pull request !19233 from zjun/fix_parameter
This commit is contained in:
commit
9fdd23f783
|
@ -1933,7 +1933,7 @@ void GradExecutor::InitResourceAndDfBuilder(const std::string &cell_id, const py
|
|||
}
|
||||
bprop_grad_stack_.push(std::make_pair(cell_id, false));
|
||||
} else if (top_cell()->grad_order() != grad_order_) {
|
||||
MakeNewTopGraph(cell_id, args, true);
|
||||
MakeNewTopGraph(cell_id, args, false);
|
||||
bprop_grad_stack_.push(std::make_pair(cell_id, true));
|
||||
}
|
||||
};
|
||||
|
@ -2108,11 +2108,7 @@ void GradExecutor::SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const p
|
|||
void GradExecutor::CreateMakeTupleNodeForMultiOut(const std::string &cell_id, const FuncGraphPtr &curr_g,
|
||||
const py::object &out, const std::string &out_id) {
|
||||
MS_EXCEPTION_IF_NULL(curr_g);
|
||||
if (!(py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out))) {
|
||||
MS_LOG(EXCEPTION) << "The out of top cell should be tuple or list when set maketuple as output node";
|
||||
}
|
||||
auto out_tuple = out.cast<py::tuple>();
|
||||
|
||||
// get input node and value
|
||||
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
|
||||
ValuePtrList input_args;
|
||||
|
@ -2139,9 +2135,14 @@ void GradExecutor::EndGraphInner(py::object *ret, const py::object &cell, const
|
|||
MS_LOG(DEBUG) << "EndGraphInner start " << args.size() << " " << cell_id;
|
||||
if (cell_stack_.empty()) {
|
||||
MS_LOG(DEBUG) << "Current cell " << cell_id << " no need to run EndGraphInner again";
|
||||
if (top_cell()->is_topest() && cell_id == top_cell()->cell_id()) {
|
||||
PopHighOrderGraphStack();
|
||||
set_grad_flag(false);
|
||||
if (cell_id == top_cell()->cell_id()) {
|
||||
if (top_cell()->is_topest()) {
|
||||
set_grad_flag(false);
|
||||
}
|
||||
auto outer_top_cell = PopHighOrderGraphStack();
|
||||
if (outer_top_cell != nullptr) {
|
||||
set_top_cell(outer_top_cell);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
@ -2151,8 +2152,7 @@ void GradExecutor::EndGraphInner(py::object *ret, const py::object &cell, const
|
|||
MS_EXCEPTION_IF_NULL(graph_info);
|
||||
if (graph_info->node_map.find(out_id) == graph_info->node_map.end()) {
|
||||
if (py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out)) {
|
||||
auto tuple_out = py::cast<py::tuple>(out);
|
||||
CreateMakeTupleNodeForMultiOut(cell_id, curr_g_, tuple_out, out_id);
|
||||
CreateMakeTupleNodeForMultiOut(cell_id, curr_g_, out, out_id);
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Set ValueNode as output for graph, out id: " << out_id;
|
||||
MakeValueNode(out, out_id);
|
||||
|
@ -2205,6 +2205,7 @@ void GradExecutor::DoGradForCustomBprop(const py::object &cell, const py::object
|
|||
if (custom_bprop_cell_count_ != 0) {
|
||||
return;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Do grad for custom bprop";
|
||||
size_t par_number = py::tuple(parse::python_adapter::CallPyObjMethod(cell, "get_parameters")).size();
|
||||
if (par_number > 0) {
|
||||
MS_LOG(EXCEPTION) << "When user defines the net bprop, there are " << par_number
|
||||
|
@ -2623,20 +2624,32 @@ void GradExecutor::MakeNestedCnode(const py::object &cell, const std::string &ce
|
|||
}
|
||||
|
||||
ValuePtrList weights_args;
|
||||
std::unordered_set<std::string> params_set;
|
||||
for (const auto &sec : second_graph_info->params) {
|
||||
if (sec.second->has_default()) {
|
||||
params_set.emplace(sec.first);
|
||||
}
|
||||
}
|
||||
auto manager = Manage({first_grad_fg}, false);
|
||||
for (const auto &fir : first_graph_info->params) {
|
||||
auto p = fir.second->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(p);
|
||||
if (!p->has_default()) {
|
||||
if (!fir.second->has_default()) {
|
||||
continue;
|
||||
}
|
||||
for (const auto &sec : second_graph_info->params) {
|
||||
MS_LOG(DEBUG) << "Param name " << fir.first << " ptr " << fir.second.get();
|
||||
if (fir.second->name() == sec.second->name()) {
|
||||
manager->Replace(fir.second, sec.second);
|
||||
inputs.emplace_back(sec.second);
|
||||
weights_args.emplace_back(sec.second->default_param());
|
||||
break;
|
||||
// The second graph does not have this weight param yet, so add it to the second graph
|
||||
if (!params_set.count(fir.first)) {
|
||||
SetParamNodeMapInGraphInfoMap(second_df_builder, fir.first, fir.second);
|
||||
inputs.emplace_back(fir.second);
|
||||
weights_args.emplace_back(fir.second->default_param());
|
||||
} else {
|
||||
// The second graph already has this weight param: replace the first graph's param node with it
|
||||
for (const auto &sec : second_graph_info->params) {
|
||||
MS_LOG(DEBUG) << "Param name " << fir.first << " ptr " << fir.second.get();
|
||||
if (sec.second->has_default() && fir.second->name() == sec.second->name()) {
|
||||
manager->Replace(fir.second, sec.second);
|
||||
inputs.emplace_back(sec.second);
|
||||
weights_args.emplace_back(sec.second->default_param());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue