!10348 Fix bprop describe

From: @zjun3021
Reviewed-by: @chujinjin,@zhoufeng54
Signed-off-by: @chujinjin
This commit is contained in:
mindspore-ci-bot 2020-12-23 13:51:12 +08:00 committed by Gitee
commit 5dfcbc3558
4 changed files with 5 additions and 4 deletions

View File

@ -52,6 +52,7 @@ FuncGraphPtr ConvertToBpropCut(const py::object &obj) {
outputs.push_back(NewValueNode(fake_bprop));
py::object code_obj = py::getattr(bprop_func, "__code__");
// Three parameters self, out and dout need to be excluded
size_t inputs_num = py::cast<int64_t>(py::getattr(code_obj, "co_argcount")) - 3;
for (size_t i = 0; i < inputs_num; ++i) {
auto param = bprop_graph->add_parameter();

View File

@ -2488,7 +2488,7 @@ void PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::arg
inputs.emplace_back(GetInput(args[i], false));
}
if (newfg->parameters().size() > inputs_size) {
SetNestedWeigthsParam(newfg, cell_id, &inputs);
SetNestedWeightsParam(newfg, cell_id, &inputs);
}
auto out_id = GetId(out);
auto cnode = graph_prev->NewCNode(inputs);
@ -2497,7 +2497,7 @@ void PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::arg
MS_LOG(DEBUG) << "Nested make cnode is " << cnode->DebugString(4);
}
void PynativeExecutor::SetNestedWeigthsParam(const FuncGraphPtr &newfg, const std::string &cell_id,
void PynativeExecutor::SetNestedWeightsParam(const FuncGraphPtr &newfg, const std::string &cell_id,
std::vector<AnfNodePtr> *inputs) {
FuncGraphPtr forward_graph = nullptr;
auto ic = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(),

View File

@ -240,7 +240,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void SetNestedTopGraph(const py::object &cell, const py::args &args, const std::string &cell_id);
void MakeNestedCnode(const std::string &cell_id, const py::args &args, const ResourcePtr &resource,
const py::object &out, bool has_sens);
void SetNestedWeigthsParam(const FuncGraphPtr &newfg, const std::string &cell_id, std::vector<AnfNodePtr> *inputs);
void SetNestedWeightsParam(const FuncGraphPtr &newfg, const std::string &cell_id, std::vector<AnfNodePtr> *inputs);
bool MakeBpropNestedCnode(const py::object &cell, const py::object &out, const std::string &cell_id);
// Hold graph(forward and grad) info

View File

@ -49,7 +49,7 @@ class Cell(Cell_):
The bprop implementation will receive a Tensor `dout` containing the gradient of the loss w.r.t.
the output, and a Tensor `out` containing the forward result. The bprop needs to compute the
gradient of the loss w.r.t. the inputs, gradient of the loss w.r.t. Parameter variables are not supported
currently.
currently. The bprop method must contain the self parameter.
Args:
auto_prefix (bool): Recursively generate namespaces. Default: True.