forked from mindspore-Ecosystem/mindspore
!1811 handle control-depend with virtual node
Merge pull request !1811 from caifubi/handle-contrl-depend-with-virtual-node
This commit is contained in:
commit
8d49de00e8
|
@ -521,6 +521,47 @@ std::vector<AnfNodePtr> KernelGraph::GetOutputNodes(const AnfNodePtr &node) {
|
|||
return output_nodes;
|
||||
}
|
||||
|
||||
// Find control_depend real input nodes.
|
||||
void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result, std::set<AnfNodePtr> *visited) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
MS_EXCEPTION_IF_NULL(visited);
|
||||
if (visited->find(anf_node) != visited->end()) {
|
||||
MS_LOG(WARNING) << "Node:" << anf_node->fullname_with_scope() << " has alreday been visited";
|
||||
return;
|
||||
}
|
||||
visited->insert(anf_node);
|
||||
if (AnfAlgo::IsRealKernel(anf_node)) {
|
||||
result->emplace_back(anf_node);
|
||||
return;
|
||||
}
|
||||
if (!anf_node->isa<CNode>()) {
|
||||
return;
|
||||
}
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->inputs().empty()) {
|
||||
MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << anf_node->DebugString();
|
||||
}
|
||||
auto input0 = cnode->input(0);
|
||||
if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
|
||||
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
||||
GetAllFatherRealNode(cnode->input(i), result, visited);
|
||||
}
|
||||
} else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
|
||||
if (cnode->inputs().size() != kTupleGetItemInputSize) {
|
||||
MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
|
||||
}
|
||||
GetAllFatherRealNode(cnode->input(kRealInputNodeIndexInTupleGetItem), result, visited);
|
||||
} else if (IsPrimitive(input0, prim::kPrimDepend)) {
|
||||
if (cnode->inputs().size() != kDependInputSize) {
|
||||
MS_LOG(EXCEPTION) << "Depend node must have 2 inputs!";
|
||||
}
|
||||
GetAllFatherRealNode(cnode->input(kRealInputIndexInDepend), result, visited);
|
||||
GetAllFatherRealNode(cnode->input(kDependAttachNodeIndex), result, visited);
|
||||
}
|
||||
}
|
||||
|
||||
// update the depend relations of control depend
|
||||
void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends) {
|
||||
for (const auto &node : depends) {
|
||||
|
@ -551,11 +592,24 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de
|
|||
if (depend_node->isa<Parameter>() && depend_mode == 1) {
|
||||
depend_nodes = GetOutputNodes(depend_node);
|
||||
}
|
||||
for (auto &first_node : prior_nodes) {
|
||||
|
||||
std::vector<AnfNodePtr> real_prior_nodes;
|
||||
std::set<AnfNodePtr> prior_visited;
|
||||
for (const auto &tmp : prior_nodes) {
|
||||
GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited);
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> real_depend_nodes;
|
||||
std::set<AnfNodePtr> depend_visited;
|
||||
for (const auto &tmp : depend_nodes) {
|
||||
GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited);
|
||||
}
|
||||
|
||||
for (auto &first_node : real_prior_nodes) {
|
||||
if (AnfAlgo::CheckPrimitiveType(first_node, prim::kPrimControlDepend)) {
|
||||
continue;
|
||||
}
|
||||
for (auto &second_node : depend_nodes) {
|
||||
for (auto &second_node : real_depend_nodes) {
|
||||
if (AnfAlgo::CheckPrimitiveType(second_node, prim::kPrimControlDepend)) {
|
||||
continue;
|
||||
}
|
||||
|
|
|
@ -252,6 +252,7 @@ constexpr auto kControlDependMode = "depend_mode";
|
|||
// index define of depend
|
||||
constexpr auto kRealInputIndexInDepend = 1;
|
||||
constexpr auto kDependAttachNodeIndex = 2;
|
||||
constexpr auto kDependInputSize = 3;
|
||||
// format
|
||||
constexpr auto kOpFormat_DEFAULT = "DefaultFormat";
|
||||
constexpr auto kOpFormat_NC1KHKWHWC0 = "NC1KHKWHWC0";
|
||||
|
|
Loading…
Reference in New Issue