diff --git a/mindspore/ccsrc/backend/common/pass/insert_type_transform_op.cc b/mindspore/ccsrc/backend/common/pass/insert_type_transform_op.cc index 0b77a68a0b5..87b1fcb002c 100644 --- a/mindspore/ccsrc/backend/common/pass/insert_type_transform_op.cc +++ b/mindspore/ccsrc/backend/common/pass/insert_type_transform_op.cc @@ -21,8 +21,8 @@ #include "backend/common/session/anf_runtime_algorithm.h" #include "include/common/utils/utils.h" #include "include/common/utils/anfalgo.h" -#include "kernel/common_utils.h" #include "include/common/utils/convert_utils.h" +#include "kernel/common_utils.h" namespace mindspore { namespace opt { @@ -240,6 +240,23 @@ void SetKernelInfoForValueNode(const ValueNodePtr &value_node) { AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), value_node.get()); } +abstract::AbstractBasePtr GenerateAbsByOpInfer(const PrimitivePtr &primitive, const AnfNodePtrList &input_list) { + MS_EXCEPTION_IF_NULL(primitive); + auto found = abstract::GetPrimitiveInferImpl(primitive); + if (!found.has_value()) { + MS_LOG(EXCEPTION) << primitive->name() << " infer is not registered."; + } + + std::vector input_args; + std::for_each(input_list.begin(), input_list.end(), + [&input_args](const auto &input) { input_args.emplace_back(input->abstract()); }); + auto infer_impl = found.value(); + auto abs = infer_impl.InferShapeAndType(nullptr, primitive, input_args); + MS_EXCEPTION_IF_NULL(abs); + MS_LOG(DEBUG) << "Abstract for " << primitive->name() << " is " << abs->ToString(); + return abs; +} + abstract::AbstractBasePtr GenerateAbsByUserNodeInput(const CNodePtr &user_node, size_t input_index) { MS_EXCEPTION_IF_NULL(user_node); auto shape = AnfAlgo::GetInputDeviceShape(user_node, input_index); @@ -507,18 +524,17 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTupleUnfoldToTensor(const FuncGraph CNodePtr tuple_to_tensor = func_graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(tuple_to_tensor); - // Set abstract for TupleToTensor op according to user node's input shape and type. + // Data type of the tensor should be set as an attr of TupleToTensor op. size_t input_index = GetInputNodeIndex(input, node); - auto abs = GenerateAbsByUserNodeInput(node, input_index); + auto data_type = AnfAlgo::GetInputDeviceDataType(node, input_index); + common::AnfAlgo::SetNodeAttr("dtype", TypeIdToType(data_type), tuple_to_tensor); + + // Set abstract for TupleToTensor op according to user node's input shape and type. + auto abs = GenerateAbsByOpInfer(prim::kPrimTupleToTensor, {input}); MS_EXCEPTION_IF_NULL(abs); MS_LOG(DEBUG) << "Abstract for TupleToTensor op is " << abs->ToString(); tuple_to_tensor->set_abstract(abs); - // Data type of the tensor should be set as an attr of TupleToTensor op. - auto data_type = AnfAlgo::GetInputDeviceDataType(node, input_index); - // Attr name is to be confirmed. - common::AnfAlgo::SetNodeAttr("tensor_type", MakeValue(static_cast(data_type)), tuple_to_tensor); - SetKernelInfoForNewCNode(tuple_to_tensor); // Set object type to TUPLE for TupleUnfoldToTuple pattern to be matched. KernelBuildInfoPtr tuple_to_tensor_build_info = AnfAlgo::GetSelectKernelBuildInfo(tuple_to_tensor); @@ -571,18 +587,17 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTupleToTensor(const FuncGraphPtr &f CNodePtr tuple_to_tensor = func_graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(tuple_to_tensor); - // Set abstract for TupleToTensor op according to user node's input shape and type. + // Data type of the tensor should be set as an attr of TupleToTensor op. size_t input_index = GetInputNodeIndex(input, node); - auto abs = GenerateAbsByUserNodeInput(node, input_index); + auto data_type = AnfAlgo::GetInputDeviceDataType(node, input_index); + common::AnfAlgo::SetNodeAttr("dtype", TypeIdToType(data_type), tuple_to_tensor); + + // Set abstract for TupleToTensor op according to user node's input shape and type. + auto abs = GenerateAbsByOpInfer(prim::kPrimTupleToTensor, {input}); MS_EXCEPTION_IF_NULL(abs); MS_LOG(DEBUG) << "Abstract for TupleToTensor op is " << abs->ToString(); tuple_to_tensor->set_abstract(abs); - // Data type of the tensor should be set as an attr of TupleToTensor op. - auto data_type = AnfAlgo::GetInputDeviceDataType(node, input_index); - // Attr name is to be confirmed. - common::AnfAlgo::SetNodeAttr("tensor_type", MakeValue(static_cast(data_type)), tuple_to_tensor); - SetKernelInfoForNewCNode(tuple_to_tensor); return {tuple_to_tensor}; } @@ -599,12 +614,11 @@ AnfNodePtrList InsertTypeTransformOp::ProcessScalarToTensor(const FuncGraphPtr & CNodePtr scalar_to_tensor = func_graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(scalar_to_tensor); - // Set abstract for ScalarToTensor op according to user node's input shape and type. - size_t input_index = GetInputNodeIndex(input, node); - auto abs = GenerateAbsByUserNodeInput(node, input_index); + auto abs = GenerateAbsByOpInfer(prim::kPrimScalarToTensor, {input}); MS_EXCEPTION_IF_NULL(abs); MS_LOG(DEBUG) << "Abstract for ScalarToTensor op is " << abs->ToString(); scalar_to_tensor->set_abstract(abs); + SetKernelInfoForNewCNode(scalar_to_tensor); return {scalar_to_tensor}; } @@ -621,9 +635,7 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTensorToTuple(const FuncGraphPtr &f CNodePtr tensor_to_tuple = func_graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(tensor_to_tuple); - // Set abstract for TensorToTuple op according to user node's input shape and type. - size_t input_index = GetInputNodeIndex(input, node); - auto abs = GenerateAbsByUserNodeInput(node, input_index); + auto abs = GenerateAbsByOpInfer(prim::kPrimTensorToTuple, {input}); MS_EXCEPTION_IF_NULL(abs); MS_LOG(DEBUG) << "Abstract for TensorToTuple op is " << abs->ToString(); tensor_to_tuple->set_abstract(abs); @@ -644,24 +656,7 @@ AnfNodePtrList InsertTypeTransformOp::ProcessTensorToScalar(const FuncGraphPtr & CNodePtr tensor_to_scalar = func_graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(tensor_to_scalar); - // Set abstract for TensorToScalar op according to user node's input shape and type. - size_t input_index = GetInputNodeIndex(input, node); - auto input_node = common::AnfAlgo::GetInputNode(node, input_index); - auto origin_input_abs = input_node->abstract(); - MS_EXCEPTION_IF_NULL(origin_input_abs); - abstract::AbstractScalarPtr abs = nullptr; - if (input_node->isa()) { - ValuePtr tensor_value = origin_input_abs->BuildValue(); - if (!tensor_value->isa()) { - MS_LOG(EXCEPTION) << "The abstract of " << input_node->DebugString() << " should be a tensor value."; - } - ValuePtr scalar_value = CreateValueFromTensor(tensor_value->cast()); - MS_EXCEPTION_IF_NULL(scalar_value); - abs = std::make_shared(scalar_value, scalar_value->type()); - } else { - abs = std::make_shared( - kAnyValue, TypeIdToType(common::AnfAlgo::GetOutputInferDataType(input_node, kIndex0))); - } + auto abs = GenerateAbsByOpInfer(prim::kPrimTensorToScalar, {input}); MS_EXCEPTION_IF_NULL(abs); MS_LOG(DEBUG) << "Abstract for TensorToScalar op is " << abs->ToString(); tensor_to_scalar->set_abstract(abs); diff --git a/mindspore/ccsrc/backend/common/pass/insert_type_transform_op.h b/mindspore/ccsrc/backend/common/pass/insert_type_transform_op.h index 0b84b786db1..50f1c7800c7 100644 --- a/mindspore/ccsrc/backend/common/pass/insert_type_transform_op.h +++ b/mindspore/ccsrc/backend/common/pass/insert_type_transform_op.h @@ -100,6 +100,9 @@ void SetKernelInfoForNewCNode(const CNodePtr &cnode, bool set_format_type = true // Set kernel info for some value nodes manually. void SetKernelInfoForValueNode(const ValueNodePtr &value_node); +// Multiplex op infer methods defined under core/ops to generate abstract of new cnode. +abstract::AbstractBasePtr GenerateAbsByOpInfer(const PrimitivePtr &primitive); + // Generate abstract, format and object type for newly created node. // They can be generated in multiple ways because new node is not processed by kernel selecting method. diff --git a/tests/ut/cpp/pre_activate/pass/insert_type_transform_op_test.cc b/tests/ut/cpp/pre_activate/pass/insert_type_transform_op_test.cc index 3e6d05e3aa9..9fe4b800538 100644 --- a/tests/ut/cpp/pre_activate/pass/insert_type_transform_op_test.cc +++ b/tests/ut/cpp/pre_activate/pass/insert_type_transform_op_test.cc @@ -361,9 +361,11 @@ TEST_F(TestInsertTypeTransformOp, test_tuple_to_tensor_transform) { ASSERT_TRUE(g != nullptr); std::vector shp_x{4, 2}; auto x_abstract = std::make_shared(kFloat32, shp_x); - std::vector shp_y{2}; + std::vector shp_y{1, 3}; auto y_abstract = std::make_shared(kFloat32, shp_y); - AbstractBasePtrList args_spec_list{x_abstract, y_abstract}; + AbstractBasePtrList abstract_list = {y_abstract}; + auto y_tuple_abs = std::make_shared(abstract_list); + AbstractBasePtrList args_spec_list{x_abstract, y_tuple_abs}; auto func_graph = GetFuncGraph(g, args_spec_list); ASSERT_TRUE(func_graph != nullptr); AnfNodePtr reshape;