From bf3db898b8e1a4f167cb13470523013fd66acc69 Mon Sep 17 00:00:00 2001 From: ZPaC Date: Sat, 17 Dec 2022 17:51:34 +0800 Subject: [PATCH] Add tuple to tensor pattern --- .../common/pass/insert_type_transform_op.cc | 25 +++++++++ .../common/pass/insert_type_transform_op.h | 4 ++ mindspore/core/ops/core_ops.h | 8 ++- .../pass/insert_type_transform_op_test.cc | 55 +++++++++++++++++-- .../insert_type_transform_op_test.py | 24 ++++++++ 5 files changed, 110 insertions(+), 6 deletions(-) diff --git a/mindspore/ccsrc/backend/common/pass/insert_type_transform_op.cc b/mindspore/ccsrc/backend/common/pass/insert_type_transform_op.cc index 00f1f75f46d..c7a6a1a62de 100644 --- a/mindspore/ccsrc/backend/common/pass/insert_type_transform_op.cc +++ b/mindspore/ccsrc/backend/common/pass/insert_type_transform_op.cc @@ -260,6 +260,9 @@ InsertTypeTransformOp::InsertTypeTransformOp(bool multigraph) kTypePairToProcessFunc[{KernelObjectType::TUPLE_UNFOLD, KernelObjectType::TENSOR}] = std::bind(&InsertTypeTransformOp::ProcessTupleUnfoldToTensor, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4); + kTypePairToProcessFunc[{KernelObjectType::TUPLE, KernelObjectType::TENSOR}] = + std::bind(&InsertTypeTransformOp::ProcessTupleToTensor, this, std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4); } const AnfNodePtr InsertTypeTransformOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, @@ -414,5 +417,27 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTupleUnfoldToTensor(const FuncGraph return {tuple_to_tensor}; } + +AnfNodePtrList InsertTypeTransformOp::ProcessTupleToTensor(const FuncGraphPtr &func_graph, const AnfNodePtr &input, + const CNodePtr &node, bool *new_prim) { + MS_EXCEPTION_IF_NULL(input); + MS_EXCEPTION_IF_NULL(node); + + // Simply insert TupleToTensor op between 'input' and 'node'. + auto prim = NewValueNode(prim::kPrimTupleToTensor); + MS_EXCEPTION_IF_NULL(prim); + AnfNodePtrList inputs = {prim, input}; + CNodePtr tuple_to_tensor = func_graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(tuple_to_tensor); + + // Set abstract for TupleToTensor op according to user node's input shape and type. + size_t input_index = GetInputNodeIndex(input, node); + auto abs = GenerateAbsByUserNodeInput(node, input_index); + MS_EXCEPTION_IF_NULL(abs); + MS_LOG(DEBUG) << "Abstract for TupleToTensor op is " << abs->ToString(); + tuple_to_tensor->set_abstract(abs); + SetKernelInfoForNewCNode(tuple_to_tensor); + return {tuple_to_tensor}; +} } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/common/pass/insert_type_transform_op.h b/mindspore/ccsrc/backend/common/pass/insert_type_transform_op.h index f37a4d7ea0c..860a1c1ccf0 100644 --- a/mindspore/ccsrc/backend/common/pass/insert_type_transform_op.h +++ b/mindspore/ccsrc/backend/common/pass/insert_type_transform_op.h @@ -127,6 +127,10 @@ class BACKEND_EXPORT InsertTypeTransformOp : public PatternProcessPass { // Convert TupleUnfold output to Tensor. Firstly insert TupleToTensor op. Then transform TupleUnfold to Tuple. AnfNodePtrList ProcessTupleUnfoldToTensor(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const CNodePtr &node, bool *new_prim); + + // Convert Tuple output to Tensor. Simply add TupleToTensor op. + AnfNodePtrList ProcessTupleToTensor(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const CNodePtr &node, + bool *new_prim); }; } // namespace opt } // namespace mindspore diff --git a/mindspore/core/ops/core_ops.h b/mindspore/core/ops/core_ops.h index 91980bf2f82..3c21c6ca049 100644 --- a/mindspore/core/ops/core_ops.h +++ b/mindspore/core/ops/core_ops.h @@ -460,6 +460,7 @@ constexpr auto kRealTupleGetItem = "RealTupleGetItem"; constexpr auto kRealMakeList = "RealMakeList"; constexpr auto kRealListGetItem = "RealListGetItem"; constexpr auto kTupleSetItem = "TupleSetItem"; +constexpr auto kTupleAdd = "TupleAdd"; GVAR_DEF(PrimitivePtr, kPrimExtractGlimpse, std::make_shared(kExtractGlimpse)); // @@ -1479,7 +1480,6 @@ GVAR_DEF(PrimitivePtr, kPrimRaise, std::make_shared("raise")); GVAR_DEF(PrimitivePtr, kPrimJoinedStr, std::make_shared("joinedstr")); GVAR_DEF(PrimitivePtr, kPrimMakeTuple, std::make_shared(kMakeTuple)); -GVAR_DEF(PrimitivePtr, kPrimRealMakeTuple, std::make_shared(kRealMakeTuple)); GVAR_DEF(PrimitivePtr, kPrimMakeSlice, std::make_shared("make_slice")); GVAR_DEF(PrimitivePtr, kPrimTupleGetItem, std::make_shared(kTupleGetItem)); GVAR_DEF(PrimitivePtr, kPrimSliceGetItem, std::make_shared(kSliceGetItem)); @@ -1681,7 +1681,6 @@ GVAR_DEF(PrimitivePtr, kPrimTupleReversed, std::make_shared("tuple_re GVAR_DEF(PrimitivePtr, kPrimReducedShape, std::make_shared("reduced_shape")); GVAR_DEF(PrimitivePtr, kPrimTupleDiv, std::make_shared("tuple_div")); GVAR_DEF(PrimitivePtr, kPrimTupleToArray, std::make_shared("tuple_to_array")); -GVAR_DEF(PrimitivePtr, kPrimTupleToTensor, std::make_shared(kTupleToTensor)); GVAR_DEF(PrimitivePtr, kPrimShapeMul, std::make_shared("shape_mul")); GVAR_DEF(PrimitivePtr, kPrimTupleEqual, std::make_shared("tuple_equal")); GVAR_DEF(PrimitivePtr, kPrimListEqual, std::make_shared("list_equal")); @@ -1730,6 +1729,11 @@ GVAR_DEF(PrimitivePtr, kPrimReservoirReplayBufferSample, std::make_shared("ReservoirReplayBufferDestroy")); GVAR_DEF(PrimitivePtr, kPrimOCRDetectionPreHandle, std::make_shared("OCRDetectionPreHandle")); +// Real tuple and list ops. +GVAR_DEF(PrimitivePtr, kPrimTupleToTensor, std::make_shared(kTupleToTensor)); +GVAR_DEF(PrimitivePtr, kPrimRealMakeTuple, std::make_shared(kRealMakeTuple)); +GVAR_DEF(PrimitivePtr, kPrimTupleAdd, std::make_shared(kTupleAdd)); + // AdamApplyOne GVAR_DEF(PrimitivePtr, kPrimAdamApplyOne, std::make_shared("AdamApplyOne")); GVAR_DEF(PrimitivePtr, kPrimAdamApplyOneAssign, std::make_shared("AdamApplyOneAssign")); diff --git a/tests/ut/cpp/pre_activate/pass/insert_type_transform_op_test.cc b/tests/ut/cpp/pre_activate/pass/insert_type_transform_op_test.cc index cbef20e69ef..169a8c888ca 100644 --- a/tests/ut/cpp/pre_activate/pass/insert_type_transform_op_test.cc +++ b/tests/ut/cpp/pre_activate/pass/insert_type_transform_op_test.cc @@ -46,6 +46,7 @@ class TestInsertTypeTransformOp : public BackendCommon { void SetTupleUnfoldToTupleKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *make_tuple_ptr, AnfNodePtr *split_ptr, AnfNodePtr *tuple_add1_ptr, AnfNodePtr *tuple_add2_ptr); void SetTupleUnfoldToTensorKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *make_tuple, AnfNodePtr *reshape); + void SetTupleToTensorKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *reshape_ptr); void SetKernelBuildInfo(const AnfNodePtr &node, const std::vector &input_formats, const std::vector &input_types, const std::vector &output_formats, @@ -165,7 +166,7 @@ void TestInsertTypeTransformOp::SetTupleUnfoldToTensorKernelBuildInfo(const Func {kNumberTypeFloat32}, {KernelObjectType::TENSOR, KernelObjectType::TENSOR}, {KernelObjectType::TENSOR}); - auto input_node3 = reshape->input(0); + auto input_node3 = reshape->input(1); SetKernelBuildInfo(input_node3, {"NCHW"}, {kNumberTypeFloat32}, {"NCHW"}, {kNumberTypeFloat32}, {KernelObjectType::TENSOR}, {KernelObjectType::TENSOR}); @@ -184,6 +185,23 @@ void TestInsertTypeTransformOp::SetTupleUnfoldToTensorKernelBuildInfo(const Func {KernelObjectType::TENSOR}, {KernelObjectType::TENSOR}); } +void TestInsertTypeTransformOp::SetTupleToTensorKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *reshape_ptr) { + auto ret = g->get_return(); + EXPECT_NE(ret->input(1), nullptr); + auto reshape = ret->input(1)->cast(); + *reshape_ptr = reshape; + SetKernelBuildInfo(reshape, {"NCHW", "NCHW"}, {kNumberTypeFloat32, kNumberTypeFloat32}, {"NCHW"}, + {kNumberTypeFloat32}, {KernelObjectType::TENSOR, KernelObjectType::TENSOR}, + {KernelObjectType::TENSOR}); + auto input2 = reshape->input(2); + SetKernelBuildInfo(input2, {"NCHW"}, {kNumberTypeFloat32}, {"NCHW"}, {kNumberTypeFloat32}, {KernelObjectType::TENSOR}, + {KernelObjectType::TUPLE}); + + auto input1 = reshape->input(1); + SetKernelBuildInfo(input1, {"NCHW"}, {kNumberTypeFloat32}, {"NCHW"}, {kNumberTypeFloat32}, {KernelObjectType::TENSOR}, + {KernelObjectType::TENSOR}); +} + void TestInsertTypeTransformOp::SetKernelBuildInfo( const AnfNodePtr &node, const std::vector &input_formats, const std::vector &input_types, const std::vector &output_formats, const std::vector &output_types, @@ -219,7 +237,8 @@ void TestInsertTypeTransformOp::CheckOutputKernelInfo(const AnfNodePtr &node, si /// Feature: Dynamic shape. /// Description: Test TupleUnfold to TupleUnfold type transforming pass. -/// Expectation: After InsertTypeTransformOp pass, the graph is identical to the expected graph expressed by python. +/// Expectation: After InsertTypeTransformOp pass, the graph is identical to the expected graph +/// expressed by python. TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_unfold_transform) { FuncGraphPtr g = getPyFun_.CallAndParseRet("test_tuple_unfold_to_tuple_unfold_transform", "before"); ASSERT_TRUE(g != nullptr); @@ -250,7 +269,8 @@ TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_unfold_transform) { /// Feature: Dynamic shape. /// Description: Test TupleUnfold to Tuple type transforming pass. -/// Expectation: After InsertTypeTransformOp pass, the graph is identical to the expected graph expressed by python. +/// Expectation: After InsertTypeTransformOp pass, the graph is identical to the expected graph +/// expressed by python. TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_transform) { FuncGraphPtr g = getPyFun_.CallAndParseRet("test_tuple_unfold_to_tuple_transform", "before"); ASSERT_TRUE(g != nullptr); @@ -277,7 +297,8 @@ TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_transform) { /// Feature: Dynamic shape. /// Description: Test TupleUnfold to Tensor type transforming pass. -/// Expectation: After InsertTypeTransformOp pass, the graph is identical to the expected graph expressed by python. +/// Expectation: After InsertTypeTransformOp pass, the graph is identical to the expected graph +/// expressed by python. TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tensor_transform) { FuncGraphPtr g = getPyFun_.CallAndParseRet("test_tuple_unfold_to_tensor_transform", "before"); ASSERT_TRUE(g != nullptr); @@ -303,5 +324,31 @@ TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tensor_transform) { ASSERT_TRUE(g_after != nullptr); EXPECT_TRUE(CheckEqualGraph(func_graph, g_after)); } + +/// Description: Test Tuple to Tensor type transforming pass. +/// Expectation: After InsertTypeTransformOp pass, the graph is identical to the expected graph expressed by python. +TEST_F(TestInsertTypeTransformOp, test_tuple_to_tensor_transform) { + FuncGraphPtr g = getPyFun_.CallAndParseRet("test_tuple_to_tensor_transform", "before"); + ASSERT_TRUE(g != nullptr); + std::vector shp_x{4, 2}; + auto x_abstract = std::make_shared(kFloat32, shp_x); + std::vector shp_y{2}; + auto y_abstract = std::make_shared(kFloat32, shp_y); + AbstractBasePtrList args_spec_list{x_abstract, y_abstract}; + auto func_graph = GetFuncGraph(g, args_spec_list); + ASSERT_TRUE(func_graph != nullptr); + AnfNodePtr reshape; + SetTupleToTensorKernelBuildInfo(func_graph, &reshape); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + optimizer->Optimize(func_graph); + + FuncGraphPtr g_after = getPyFun_.CallAndParseRet("test_tuple_to_tensor_transform", "after"); + ASSERT_TRUE(g_after != nullptr); + EXPECT_TRUE(CheckEqualGraph(func_graph, g_after)); +} } // namespace opt } // namespace mindspore diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/insert_type_transform_op_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/insert_type_transform_op_test.py index 1f8c4b1d972..03c025fb838 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/insert_type_transform_op_test.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/insert_type_transform_op_test.py @@ -124,3 +124,27 @@ def test_tuple_unfold_to_tensor_transform(tag): return res return fns[tag] + + +def test_tuple_to_tensor_transform(tag): + """ + Feature: Dynamic shape. + Description: Test Tuple to Tensor transforming pass. + Expectation: The 'after' graph is identical to the graph after this pass. + """ + fns = FnDict() + reshape = P.Reshape() + tuple_to_tensor = Primitive('TupleToTensor') + + @fns + def before(x, y): + res = reshape(x, y) + return res + + @fns + def after(x, y): + res = tuple_to_tensor(y) + res = reshape(x, res) + return res + + return fns[tag]