fix bug of pynative in back gradient graph mode's transdata inserting when the node is the output of the graph

This commit is contained in:
William Lian 2020-09-19 11:15:12 +08:00
parent a3484c1c0c
commit 25fcd0488e
3 changed files with 13 additions and 4 deletions

View File

@ -18,6 +18,7 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "utils/utils.h" #include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
#include "backend/optimizer/ascend/ascend_helper.h" #include "backend/optimizer/ascend/ascend_helper.h"
#include "backend/session/anf_runtime_algorithm.h" #include "backend/session/anf_runtime_algorithm.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"
@ -30,12 +31,12 @@ const BaseRef InsertTransOp::DefinePattern() const {
return VectorRef({V, Xs}); return VectorRef({V, Xs});
} }
bool IsGraphOutput(const AnfNodePtr &node, const std::vector<AnfNodePtr> &outputs) { bool IsGraphOutput(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
auto outputs = AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem});
auto iter = std::find(outputs.begin(), outputs.end(), node); auto iter = std::find(outputs.begin(), outputs.end(), node);
if (iter != outputs.end()) { if (iter != outputs.end() && GetRealNodeNum(func_graph, node) == 1) {
return true; return true;
} }
return false; return false;
} }
@ -55,7 +56,7 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode && if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode &&
!ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK)) { !ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK)) {
if (IsGraphOutput(node, AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}))) { if (IsGraphOutput(node, func_graph)) {
return new_node; return new_node;
} }
} }

View File

@ -485,6 +485,12 @@ void RemoveNopNode(session::KernelGraph *const graph) {
} }
} }
size_t GetRealNodeNum(const FuncGraphPtr &graph, const AnfNodePtr &node) {
auto out_list = GetRealNodeUsedList(graph, node);
MS_EXCEPTION_IF_NULL(out_list);
return out_list->size();
}
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph, std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
const AnfNodePtr &node) { const AnfNodePtr &node) {
auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>(); auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();

View File

@ -172,6 +172,8 @@ bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node);
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph, std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
const AnfNodePtr &node); const AnfNodePtr &node);
size_t GetRealNodeNum(const FuncGraphPtr &graph, const AnfNodePtr &node);
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph, std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
const AnfNodePtr &node, const AnfNodePtr &node,
size_t output_index); size_t output_index);