fix bug of pynative in back gradient graph mode's transdata inserting when the node is the output of the graph

This commit is contained in:
William Lian 2020-09-19 11:15:12 +08:00
parent a3484c1c0c
commit 25fcd0488e
3 changed files with 13 additions and 4 deletions

View File

@ -18,6 +18,7 @@
#include <memory>
#include <vector>
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
#include "backend/optimizer/ascend/ascend_helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/ms_context.h"
@ -30,12 +31,12 @@ const BaseRef InsertTransOp::DefinePattern() const {
return VectorRef({V, Xs});
}
bool IsGraphOutput(const AnfNodePtr &node, const std::vector<AnfNodePtr> &outputs) {
bool IsGraphOutput(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
auto outputs = AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem});
auto iter = std::find(outputs.begin(), outputs.end(), node);
if (iter != outputs.end()) {
if (iter != outputs.end() && GetRealNodeNum(func_graph, node) == 1) {
return true;
}
return false;
}
@ -55,7 +56,7 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode &&
!ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK)) {
if (IsGraphOutput(node, AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}))) {
if (IsGraphOutput(node, func_graph)) {
return new_node;
}
}

View File

@ -485,6 +485,12 @@ void RemoveNopNode(session::KernelGraph *const graph) {
}
}
size_t GetRealNodeNum(const FuncGraphPtr &graph, const AnfNodePtr &node) {
auto out_list = GetRealNodeUsedList(graph, node);
MS_EXCEPTION_IF_NULL(out_list);
return out_list->size();
}
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
const AnfNodePtr &node) {
auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();

View File

@ -172,6 +172,8 @@ bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node);
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
const AnfNodePtr &node);
size_t GetRealNodeNum(const FuncGraphPtr &graph, const AnfNodePtr &node);
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
const AnfNodePtr &node,
size_t output_index);