!4573 Fix do fission when UnsegmentSum's input0 shape size is 1

Merge pull request !4573 from huanghui/unsorted-segment-sum-fission-pass
This commit is contained in:
mindspore-ci-bot 2020-08-17 16:08:15 +08:00 committed by Gitee
commit 50e78118fb
1 changed files with 4 additions and 0 deletions

View File

@ -94,6 +94,10 @@ const AnfNodePtr UnsortSegmentSumFission::Process(const FuncGraphPtr &graph, con
return nullptr;
}
auto input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0);
if (input0_shape.size() < 2) {
MS_LOG(INFO) << "Input0's shape size less than 2, not optimize";
return nullptr;
}
if (input0_shape[input0_shape.size() - 1] != 1) {
MS_LOG(INFO) << "UnsortedSegmentSum is not need fission. The last value of input0's shape is "
<< input0_shape[input0_shape.size() - 1];