forked from mindspore-Ecosystem/mindspore
fix bug of pynative in back gradient graph mode's transdata inserting when the node is the output of the graph
This commit is contained in:
parent
a3484c1c0c
commit
25fcd0488e
|
@ -18,6 +18,7 @@
|
|||
#include <memory>
|
||||
#include <vector>
|
||||
#include "utils/utils.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "backend/optimizer/ascend/ascend_helper.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
@ -30,12 +31,12 @@ const BaseRef InsertTransOp::DefinePattern() const {
|
|||
return VectorRef({V, Xs});
|
||||
}
|
||||
|
||||
bool IsGraphOutput(const AnfNodePtr &node, const std::vector<AnfNodePtr> &outputs) {
|
||||
bool IsGraphOutput(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
|
||||
auto outputs = AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem});
|
||||
auto iter = std::find(outputs.begin(), outputs.end(), node);
|
||||
if (iter != outputs.end()) {
|
||||
if (iter != outputs.end() && GetRealNodeNum(func_graph, node) == 1) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -55,7 +56,7 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
|
|||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode &&
|
||||
!ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK)) {
|
||||
if (IsGraphOutput(node, AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}))) {
|
||||
if (IsGraphOutput(node, func_graph)) {
|
||||
return new_node;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -485,6 +485,12 @@ void RemoveNopNode(session::KernelGraph *const graph) {
|
|||
}
|
||||
}
|
||||
|
||||
size_t GetRealNodeNum(const FuncGraphPtr &graph, const AnfNodePtr &node) {
|
||||
auto out_list = GetRealNodeUsedList(graph, node);
|
||||
MS_EXCEPTION_IF_NULL(out_list);
|
||||
return out_list->size();
|
||||
}
|
||||
|
||||
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
|
||||
const AnfNodePtr &node) {
|
||||
auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
|
||||
|
|
|
@ -172,6 +172,8 @@ bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node);
|
|||
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
|
||||
const AnfNodePtr &node);
|
||||
|
||||
size_t GetRealNodeNum(const FuncGraphPtr &graph, const AnfNodePtr &node);
|
||||
|
||||
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
|
||||
const AnfNodePtr &node,
|
||||
size_t output_index);
|
||||
|
|
Loading…
Reference in New Issue