forked from mindspore-Ecosystem/mindspore
!2840 Update fusion condition of reshape and transpose
Merge pull request !2840 from YuJianfeng/master
This commit is contained in:
commit
7d5db0e99b
|
@ -51,8 +51,8 @@ const AnfNodePtr ReshapeTransposeFusion::Process(const FuncGraphPtr &func_graph,
|
||||||
auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(transpose_cnode->input(1), kBackendReshapeInputNum);
|
auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(transpose_cnode->input(1), kBackendReshapeInputNum);
|
||||||
MS_EXCEPTION_IF_NULL(reshape_cnode);
|
MS_EXCEPTION_IF_NULL(reshape_cnode);
|
||||||
std::vector<size_t> reshape_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(reshape_cnode, 0);
|
std::vector<size_t> reshape_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(reshape_cnode, 0);
|
||||||
std::vector<size_t> transpose_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(transpose_cnode, 0);
|
std::vector<size_t> transpose_output0_shape = AnfAlgo::GetOutputInferShape(transpose_cnode, 0);
|
||||||
if (!CheckShapeDimInfo(reshape_input0_shape) || !CheckShapeDimInfo(transpose_input0_shape)) {
|
if (!CheckShapeDimInfo(reshape_input0_shape) || !CheckShapeDimInfo(transpose_output0_shape)) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto prim = std::make_shared<Primitive>(kConfusionTransposeDOpName);
|
auto prim = std::make_shared<Primitive>(kConfusionTransposeDOpName);
|
||||||
|
|
Loading…
Reference in New Issue