forked from mindspore-Ecosystem/mindspore
Add input shape condition for transpose_reshape fusion pass
This commit is contained in:
parent
298a784878
commit
85ff90c237
|
@ -23,6 +23,18 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
bool CheckShapeDimInfo(const std::vector<size_t> &shape) {
|
||||
if (shape.empty()) {
|
||||
return false;
|
||||
}
|
||||
if (shape.size() == 1 && shape[0] % kCubeSize != 0) {
|
||||
return false;
|
||||
}
|
||||
return !(shape.size() >= 2 && (shape[shape.size() - 1] % kCubeSize != 0 || shape[shape.size() - 2] % kCubeSize != 0));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const BaseRef ReshapeTransposeFusion::DefinePattern() const {
|
||||
const auto prim_reshape = std::make_shared<Primitive>(prim::kPrimReshape->name());
|
||||
VectorRef reshape({prim_reshape, input_varptr_});
|
||||
|
@ -38,6 +50,11 @@ const AnfNodePtr ReshapeTransposeFusion::Process(const FuncGraphPtr &func_graph,
|
|||
MS_EXCEPTION_IF_NULL(transpose_cnode);
|
||||
auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(transpose_cnode->input(1), kBackendReshapeInputNum);
|
||||
MS_EXCEPTION_IF_NULL(reshape_cnode);
|
||||
std::vector<size_t> reshape_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(reshape_cnode, 0);
|
||||
std::vector<size_t> transpose_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(transpose_cnode, 0);
|
||||
if (!CheckShapeDimInfo(reshape_input0_shape) || !CheckShapeDimInfo(transpose_input0_shape)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto prim = std::make_shared<Primitive>(kConfusionTransposeDOpName);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), utils::cast<AnfNodePtr>((*equiv)[input_varptr_])};
|
||||
auto new_node = func_graph->NewCNode(inputs);
|
||||
|
|
|
@ -23,6 +23,18 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
bool CheckShapeDimInfo(const std::vector<size_t> &shape) {
|
||||
if (shape.empty()) {
|
||||
return false;
|
||||
}
|
||||
if (shape.size() == 1 && shape[0] % kCubeSize != 0) {
|
||||
return false;
|
||||
}
|
||||
return !(shape.size() >= 2 && (shape[shape.size() - 1] % kCubeSize != 0 || shape[shape.size() - 2] % kCubeSize != 0));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const BaseRef TransposeReshapeFusion::DefinePattern() const {
|
||||
const auto prim_reshape = std::make_shared<Primitive>(prim::kPrimReshape->name());
|
||||
VectorRef transpose({prim::kPrimTranspose, input_varptr_});
|
||||
|
@ -38,6 +50,11 @@ const AnfNodePtr TransposeReshapeFusion::Process(const FuncGraphPtr &func_graph,
|
|||
MS_EXCEPTION_IF_NULL(reshape_cnode);
|
||||
auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(reshape_cnode->input(1), kBackendReshapeInputNum);
|
||||
MS_EXCEPTION_IF_NULL(transpose_cnode);
|
||||
std::vector<size_t> reshape_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(reshape_cnode, 0);
|
||||
std::vector<size_t> transpose_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(transpose_cnode, 0);
|
||||
if (!CheckShapeDimInfo(reshape_input0_shape) || !CheckShapeDimInfo(transpose_input0_shape)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto prim = std::make_shared<Primitive>(kConfusionTransposeDOpName);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), utils::cast<AnfNodePtr>((*equiv)[input_varptr_])};
|
||||
auto new_node = func_graph->NewCNode(inputs);
|
||||
|
|
|
@ -39,7 +39,7 @@ TEST_F(TestHWReshapeTransposeFusion, test_reshape_transpose_fusion) {
|
|||
* return transpose
|
||||
*/
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_reshape_transpose_fusion", "before");
|
||||
std::vector<int> shp{2, 4, 8, 16};
|
||||
std::vector<int> shp{2, 2, 16, 16};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
AbstractBasePtrList args_spec_list{x_abstract};
|
||||
auto kg = GetKernelGraph(g, args_spec_list);
|
||||
|
@ -59,5 +59,26 @@ TEST_F(TestHWReshapeTransposeFusion, test_reshape_transpose_fusion) {
|
|||
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_reshape_transpose_fusion", "after");
|
||||
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
||||
}
|
||||
|
||||
TEST_F(TestHWReshapeTransposeFusion, test_reshape_transpose_no_fusion) {
|
||||
/*
|
||||
* def before(input0, input1):
|
||||
* reshape = Reshape(input0, input1)
|
||||
* transpose = Transpose(reshape)
|
||||
* return transpose
|
||||
*/
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_reshape_transpose_fusion", "before");
|
||||
std::vector<int> shp{2, 4, 8, 16};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
AbstractBasePtrList args_spec_list{x_abstract};
|
||||
auto kg = GetKernelGraph(g, args_spec_list);
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::ReshapeTransposeFusion>());
|
||||
optimizer->AddPassManager(pm);
|
||||
FuncGraphPtr new_graph = optimizer->Optimize(kg);
|
||||
EXPECT_TRUE(CheckEqualGraph(kg, new_graph));
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -39,7 +39,7 @@ TEST_F(TestHWTransposeReshapeFusion, test_transpose_reshape_fusion) {
|
|||
* return transpose
|
||||
*/
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_transpose_reshape_fusion", "before");
|
||||
std::vector<int> shp{2, 4, 8, 16};
|
||||
std::vector<int> shp{2, 2, 16, 16};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
AbstractBasePtrList args_spec_list{x_abstract};
|
||||
auto kg = GetKernelGraph(g, args_spec_list);
|
||||
|
@ -61,5 +61,26 @@ TEST_F(TestHWTransposeReshapeFusion, test_transpose_reshape_fusion) {
|
|||
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_transpose_reshape_fusion", "after");
|
||||
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
||||
}
|
||||
|
||||
TEST_F(TestHWTransposeReshapeFusion, test_transpose_reshape_no_fusion) {
|
||||
/*
|
||||
* def before(input0, input1):
|
||||
* reshape = Reshape(input0, input1)
|
||||
* transpose = Transpose(reshape)
|
||||
* return transpose
|
||||
*/
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_transpose_reshape_fusion", "before");
|
||||
std::vector<int> shp{2, 4, 8, 16};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
AbstractBasePtrList args_spec_list{x_abstract};
|
||||
auto kg = GetKernelGraph(g, args_spec_list);
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::TransposeReshapeFusion>());
|
||||
optimizer->AddPassManager(pm);
|
||||
FuncGraphPtr new_graph = optimizer->Optimize(kg);
|
||||
EXPECT_TRUE(CheckEqualGraph(kg, new_graph));
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -36,7 +36,7 @@ def test_reshape_transpose_fusion(tag):
|
|||
|
||||
@fns
|
||||
def before(input0):
|
||||
reshape = Reshape(input0, (2, 4, 8, 16))
|
||||
reshape = Reshape(input0, (2, 2, 16, 16))
|
||||
transpose = Transpose(reshape, (1, 0, 2, 3))
|
||||
return transpose
|
||||
|
||||
|
|
Loading…
Reference in New Issue