diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_reshape_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_reshape_fusion.cc index 2f982e0413e..250f86d9b1e 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_reshape_fusion.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_reshape_fusion.cc @@ -50,9 +50,9 @@ const AnfNodePtr TransposeReshapeFusion::Process(const FuncGraphPtr &func_graph, MS_EXCEPTION_IF_NULL(reshape_cnode); auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(reshape_cnode->input(1), kBackendReshapeInputNum); MS_EXCEPTION_IF_NULL(transpose_cnode); - std::vector reshape_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(reshape_cnode, 0); + std::vector reshape_output0_shape = AnfAlgo::GetOutputInferShape(reshape_cnode, 0); std::vector transpose_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(transpose_cnode, 0); - if (!CheckShapeDimInfo(reshape_input0_shape) || !CheckShapeDimInfo(transpose_input0_shape)) { + if (!CheckShapeDimInfo(reshape_output0_shape) || !CheckShapeDimInfo(transpose_input0_shape)) { return nullptr; } auto prim = std::make_shared(kConfusionTransposeDOpName); diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/transpose_reshape_fusion_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/transpose_reshape_fusion_test.py index 405bcf0cec1..2772b11eea0 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/transpose_reshape_fusion_test.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/transpose_reshape_fusion_test.py @@ -38,7 +38,7 @@ def test_transpose_reshape_fusion(tag): @fns def before(x): transpose = Transpose(x, (1, 0, 2, 3)) - reshape = Reshape(transpose, (2, 4, 8, 16)) + reshape = Reshape(transpose, (2, 2, 16, 16)) return reshape @fns