forked from mindspore-Ecosystem/mindspore
Fix fusion condition of transpose and reshape
This commit is contained in:
parent
7f80d02807
commit
94818cf255
|
@ -50,9 +50,9 @@ const AnfNodePtr TransposeReshapeFusion::Process(const FuncGraphPtr &func_graph,
|
||||||
MS_EXCEPTION_IF_NULL(reshape_cnode);
|
MS_EXCEPTION_IF_NULL(reshape_cnode);
|
||||||
auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(reshape_cnode->input(1), kBackendReshapeInputNum);
|
auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(reshape_cnode->input(1), kBackendReshapeInputNum);
|
||||||
MS_EXCEPTION_IF_NULL(transpose_cnode);
|
MS_EXCEPTION_IF_NULL(transpose_cnode);
|
||||||
std::vector<size_t> reshape_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(reshape_cnode, 0);
|
std::vector<size_t> reshape_output0_shape = AnfAlgo::GetOutputInferShape(reshape_cnode, 0);
|
||||||
std::vector<size_t> transpose_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(transpose_cnode, 0);
|
std::vector<size_t> transpose_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(transpose_cnode, 0);
|
||||||
if (!CheckShapeDimInfo(reshape_input0_shape) || !CheckShapeDimInfo(transpose_input0_shape)) {
|
if (!CheckShapeDimInfo(reshape_output0_shape) || !CheckShapeDimInfo(transpose_input0_shape)) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto prim = std::make_shared<Primitive>(kConfusionTransposeDOpName);
|
auto prim = std::make_shared<Primitive>(kConfusionTransposeDOpName);
|
||||||
|
|
|
@ -38,7 +38,7 @@ def test_transpose_reshape_fusion(tag):
|
||||||
@fns
|
@fns
|
||||||
def before(x):
|
def before(x):
|
||||||
transpose = Transpose(x, (1, 0, 2, 3))
|
transpose = Transpose(x, (1, 0, 2, 3))
|
||||||
reshape = Reshape(transpose, (2, 4, 8, 16))
|
reshape = Reshape(transpose, (2, 2, 16, 16))
|
||||||
return reshape
|
return reshape
|
||||||
|
|
||||||
@fns
|
@fns
|
||||||
|
|
Loading…
Reference in New Issue