From 50f90fb8df5bd7d792c833741f035ce5c99279a1 Mon Sep 17 00:00:00 2001 From: ZPaC Date: Tue, 7 Feb 2023 19:27:32 +0800 Subject: [PATCH] Reset RealMakeTuple objtypes --- .../common/pass/insert_type_transform_op.cc | 11 +++++++ .../pass/insert_type_transform_op_test.cc | 33 ++++++------------- .../insert_type_transform_op_test.py | 15 +++++---- 3 files changed, 29 insertions(+), 30 deletions(-) diff --git a/mindspore/ccsrc/backend/common/pass/insert_type_transform_op.cc b/mindspore/ccsrc/backend/common/pass/insert_type_transform_op.cc index c736ba08ddb..73cca82ec41 100644 --- a/mindspore/ccsrc/backend/common/pass/insert_type_transform_op.cc +++ b/mindspore/ccsrc/backend/common/pass/insert_type_transform_op.cc @@ -139,6 +139,17 @@ AnfNodePtr CreateRealMakeTupleByMakeTuple(const FuncGraphPtr &func_graph, const real_make_tuple->set_abstract(make_tuple_node->abstract()); SetKernelInfoForNewCNode(real_make_tuple); + + // RealMakeTuple's inputs must be scalar. To avoid kernel selection failure, we must override RealMakeTuple's + // KernelObjectTypes, which are created from MakeTuple. 
+ KernelBuildInfoPtr real_make_tuple_build_info = AnfAlgo::GetSelectKernelBuildInfo(real_make_tuple); + MS_EXCEPTION_IF_NULL(real_make_tuple_build_info); + auto inputs_obj_types = real_make_tuple_build_info->GetAllInputKernelObjectTypes(); + auto new_obj_types = inputs_obj_types; + std::transform(new_obj_types.begin(), new_obj_types.end(), new_obj_types.begin(), + [](const auto &obj_type) { return KernelObjectType::SCALAR; }); + real_make_tuple_build_info->SetInputsKernelObjectType(new_obj_types); + MS_LOG(DEBUG) << "Override RealMakeTuple input kernel object types from " << inputs_obj_types << " to " << new_obj_types; return real_make_tuple; } diff --git a/tests/ut/cpp/pre_activate/pass/insert_type_transform_op_test.cc b/tests/ut/cpp/pre_activate/pass/insert_type_transform_op_test.cc index b265c399e0c..c582882ce8a 100644 --- a/tests/ut/cpp/pre_activate/pass/insert_type_transform_op_test.cc +++ b/tests/ut/cpp/pre_activate/pass/insert_type_transform_op_test.cc @@ -45,8 +45,8 @@ class TestInsertTypeTransformOp : public BackendCommon { public: void SetTupleUnfoldToTupleUnfoldKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *split1_ptr, AnfNodePtr *addn1_ptr, AnfNodePtr *split2_ptr, AnfNodePtr *addn2_ptr); - void SetTupleUnfoldToTupleKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *make_tuple_ptr, AnfNodePtr *split_ptr, - AnfNodePtr *seq_add1_ptr, AnfNodePtr *seq_add2_ptr); + void SetTupleUnfoldToTupleKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *make_tuple_ptr, + AnfNodePtr *seq_add1_ptr); void SetTupleUnfoldToTensorKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *make_tuple, AnfNodePtr *reshape_ptr); void SetTupleToTupleUnfoldKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *shape_ptr); void SetTupleToTensorKernelBuildInfo(const FuncGraphPtr &g, AnfNodePtr *reshape_ptr); @@ -126,23 +126,10 @@ void TestInsertTypeTransformOp::SetTupleUnfoldToTupleUnfoldKernelBuildInfo( } void TestInsertTypeTransformOp::SetTupleUnfoldToTupleKernelBuildInfo(const FuncGraphPtr &g, 
AnfNodePtr *make_tuple_ptr, - AnfNodePtr *split_ptr, AnfNodePtr *seq_add1_ptr, - AnfNodePtr *seq_add2_ptr) { + AnfNodePtr *seq_add1_ptr) { auto ret = g->get_return(); EXPECT_NE(ret->input(1), nullptr); - auto seq_add2 = ret->input(1)->cast(); - *seq_add2_ptr = seq_add2; - SetKernelBuildInfo(seq_add2, {"NCHW", "NCHW"}, {kNumberTypeFloat32, kNumberTypeFloat32}, {"NCHW"}, - {kNumberTypeFloat32}, {KernelObjectType::TUPLE, KernelObjectType::TUPLE}, - {KernelObjectType::TENSOR}); - - auto split = seq_add2->input(1)->cast(); - *split_ptr = split; - MS_LOG(INFO) << "split is " << split->fullname_with_scope(); - SetKernelBuildInfo(split, {"NCHW"}, {kNumberTypeFloat32}, {"NCHW", "NCHW"}, {kNumberTypeFloat32, kNumberTypeFloat32}, - {KernelObjectType::TENSOR}, {KernelObjectType::TUPLE_UNFOLD}); - - auto seq_add1 = split->input(1)->cast(); + auto seq_add1 = ret->input(1)->cast(); *seq_add1_ptr = seq_add1; SetKernelBuildInfo(seq_add1, {"NCHW", "NCHW"}, {kNumberTypeFloat32, kNumberTypeFloat32}, {"NCHW"}, {kNumberTypeFloat32}, {KernelObjectType::TUPLE, KernelObjectType::TUPLE}, @@ -336,9 +323,9 @@ TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_unfold_transform) { TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_transform) { FuncGraphPtr g = getPyFun_.CallAndParseRet("test_tuple_unfold_to_tuple_transform", "before"); ASSERT_TRUE(g != nullptr); - std::vector shp_1{2, 4}; + std::vector shp_1{}; auto abstract_1 = std::make_shared(kFloat32, shp_1); - std::vector shp_2{2, 4}; + std::vector shp_2{}; auto abstract_2 = std::make_shared(kFloat32, shp_2); std::vector shp_x{1, 3}; auto x_abstract = std::make_shared(kFloat32, shp_x); @@ -347,8 +334,8 @@ TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_transform) { AbstractBasePtrList args_spec_list{abstract_1, abstract_2, x_tuple_abs}; auto func_graph = GetFuncGraph(g, args_spec_list); ASSERT_TRUE(func_graph != nullptr); - AnfNodePtr make_tuple, split, seq_add1, seq_add2; - 
SetTupleUnfoldToTupleKernelBuildInfo(func_graph, &make_tuple, &split, &seq_add1, &seq_add2); + AnfNodePtr make_tuple, seq_add1; + SetTupleUnfoldToTupleKernelBuildInfo(func_graph, &make_tuple, &seq_add1); auto optimizer = std::make_shared(); auto pm = std::make_shared(); @@ -374,9 +361,9 @@ TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tuple_transform) { TEST_F(TestInsertTypeTransformOp, test_tuple_unfold_to_tensor_transform) { FuncGraphPtr g = getPyFun_.CallAndParseRet("test_tuple_unfold_to_tensor_transform", "before"); ASSERT_TRUE(g != nullptr); - std::vector shp_x{4}; + std::vector shp_x{}; auto x_abstract = std::make_shared(kFloat32, shp_x); - std::vector shp_y{2}; + std::vector shp_y{}; auto y_abstract = std::make_shared(kFloat32, shp_y); std::vector shp_z{2, 4}; auto z_abstract = std::make_shared(kFloat32, shp_z); diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/insert_type_transform_op_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/insert_type_transform_op_test.py index 61391b8667f..5bb300d7a0f 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/insert_type_transform_op_test.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/insert_type_transform_op_test.py @@ -18,6 +18,7 @@ from mindspore.common.tensor import Tensor from mindspore.ops import Primitive from mindspore.ops import operations as P from mindspore.ops import _constants as Constants +from mindspore.ops.operations import _sequence_ops as seq tuple_get_item = Primitive(Constants.kTupleGetItem) @@ -76,25 +77,22 @@ def test_tuple_unfold_to_tuple_transform(tag): """ fns = FnDict() # Need to change AddN to SequenceAdd in later version. This case is just used to cover this pattern. 
- seq_add1 = P.AddN() - seq_add2 = P.AddN() + seq_add1 = seq.SequenceAdd() + tensor_to_scalar = Primitive('TensorToScalar') real_make_tuple = Primitive('RealMakeTuple') @fns def before(input_1, input_2, x): res = make_tuple(input_1, input_2) res = seq_add1(res, x) - res = split1(res) - res = seq_add2(res, x) return res @fns def after(input_1, input_2, x): + input_1 = tensor_to_scalar(input_1) + input_2 = tensor_to_scalar(input_2) res = real_make_tuple(input_1, input_2) res = seq_add1(res, x) - res = split1(res) - res = real_make_tuple(tuple_get_item(res, 0), tuple_get_item(res, 1)) - res = seq_add2(res, x) return res return fns[tag] @@ -110,6 +108,7 @@ def test_tuple_unfold_to_tensor_transform(tag): reshape = P.Reshape() real_make_tuple = Primitive('RealMakeTuple') tuple_to_tensor = Primitive('TupleToTensor') + tensor_to_scalar = Primitive('TensorToScalar') @fns def before(input_1, input_2, input_3): @@ -119,6 +118,8 @@ def test_tuple_unfold_to_tensor_transform(tag): @fns def after(input_1, input_2, input_3): + input_1 = tensor_to_scalar(input_1) + input_2 = tensor_to_scalar(input_2) res = real_make_tuple(input_1, input_2) res = tuple_to_tensor(res) res = reshape(input_3, res)