!21577 fix event insertion bug

Merge pull request !21577 from fix-stream-1.4
This commit is contained in:
zhangzhenghai 2021-08-10 06:57:38 +00:00 committed by Gitee
commit fb2b4f0617
2 changed files with 26 additions and 13 deletions

View File

@ -1992,6 +1992,28 @@ CNodePtr AscendStreamAssign::CreateRecvApplyKernel(const NotNull<KernelGraphPtr>
return recv_node_ptr;
}
// Determines whether `target_node` is reachable through the inputs of `nop_node`.
//
// Inputs of `nop_node` are scanned starting at index 1. An input that is itself a nop node
// (per opt::IsNopNode) is searched recursively; any other input is resolved with
// AnfAlgo::VisitKernel(input, 0) and the resulting node is compared against `target_node`.
//
// Params:
//   nop_node     - node whose inputs are searched; must be non-null and castable to CNodePtr.
//   target_node  - node being looked for among the resolved inputs.
//   cur_node     - node used only for the hcom check; forwarded unchanged through the recursion.
//   exclude_hcom - when true, a match is ignored if IsHcom(cur_node) holds.
//
// Returns true on the first accepted match, false if no input (direct or via nested nop nodes)
// produces one. Raises via MS_EXCEPTION_IF_NULL if `nop_node` is null or is not a CNode.
bool AscendStreamAssign::IsNopNodeTarget(const AnfNodePtr &nop_node, const CNodePtr &target_node,
                                         const CNodePtr &cur_node, bool exclude_hcom) {
  MS_EXCEPTION_IF_NULL(nop_node);
  auto cnode = nop_node->cast<CNodePtr>();
  // Guard against a failed cast before dereferencing, consistent with the check on nop_node above.
  MS_EXCEPTION_IF_NULL(cnode);
  // Bind by reference: the input list is only read, so copying it on every recursive call is wasted work.
  const auto &new_inputs = cnode->inputs();
  for (size_t i = 1; i < new_inputs.size(); i++) {
    if (opt::IsNopNode(new_inputs[i])) {
      // Nested nop node: keep looking through it for the real producer.
      if (IsNopNodeTarget(new_inputs[i], target_node, cur_node, exclude_hcom)) {
        return true;
      }
    } else {
      auto new_real_input = AnfAlgo::VisitKernel(new_inputs[i], 0);
      if (target_node == new_real_input.first) {
        // A match is rejected (and scanning continues) when hcom nodes are excluded and cur_node is one.
        if (!(exclude_hcom && IsHcom(cur_node))) {
          return true;
        }
      }
    }
  }
  return false;
}
vector<CNodePtr>::iterator AscendStreamAssign::FindTargetOp(vector<CNodePtr>::iterator begin,
vector<CNodePtr>::iterator end, const CNodePtr &node,
bool exclude_hcom) {
@ -2000,18 +2022,8 @@ vector<CNodePtr>::iterator AscendStreamAssign::FindTargetOp(vector<CNodePtr>::it
for (size_t i = 1; i < inputs.size(); i++) {
auto input = inputs[i];
if (opt::IsNopNode(input)) {
CNodePtr cnode = input->cast<CNodePtr>();
auto new_inputs = cnode->inputs();
for (size_t j = 1; j < new_inputs.size(); j++) {
auto new_real_input = AnfAlgo::VisitKernel(new_inputs[j], 0);
// find target node except hcom op. insert event for hcom in:InsertEventHcomDependCommonBak function
// only insert one time
if (node == new_real_input.first) {
if (!(exclude_hcom && IsHcom(*begin))) {
MS_LOG(DEBUG) << "Nop node find target op[" << (*begin)->DebugString() << "]";
return begin;
}
}
if (IsNopNodeTarget(input, node, *begin, exclude_hcom)) {
return begin;
}
} else {
auto real_input = AnfAlgo::VisitKernel(input, 0);

View File

@ -175,7 +175,8 @@ class AscendStreamAssign {
uint32_t GetIndexByKey(const NotNull<KernelGraphPtr> &graph_ptr, const CNodeKey &key);
uint32_t GetIndependentStreamSwitchStreamId(const NotNull<KernelGraphPtr> &graph_ptr);
void GetIndependentMaxTarget(const NotNull<KernelGraphPtr> &graph_ptr);
// Returns true when target_node is found among the inputs of nop_node, searching recursively
// through inputs that are themselves nop nodes. A match is ignored when exclude_hcom is true
// and IsHcom(cur_node) holds; cur_node is used only for that check.
bool IsNopNodeTarget(const AnfNodePtr &nop_node, const CNodePtr &target_node, const CNodePtr &cur_node,
bool exclude_hcom);
bool IsTaskSink();
bool IsHcom(const CNodePtr &cur_cnode_ptr);
bool IsIndependentNode(const CNodePtr &node_ptr);