diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/constant_folding_fusion_test.cc b/mindspore/lite/test/ut/tools/optimizer/fusion/constant_folding_fusion_test.cc
index 912d2dd38b9..cabc4a73fa7 100644
--- a/mindspore/lite/test/ut/tools/optimizer/fusion/constant_folding_fusion_test.cc
+++ b/mindspore/lite/test/ut/tools/optimizer/fusion/constant_folding_fusion_test.cc
@@ -236,6 +236,136 @@ MetaGraphTptr BuildMixGraph() {
   // final output
   return meta_graph;
 }
+MetaGraphTptr BuildSplitGraph() {
+  auto meta_graph = std::make_shared<schema::MetaGraphT>();
+  meta_graph->name = "graph";
+  // split node
+  auto split_node = std::make_unique<schema::CNodeT>();
+  split_node->inputIndex = {0};
+  split_node->outputIndex = {1, 2};
+  split_node->primitive = std::make_unique<schema::PrimitiveT>();
+  split_node->primitive->value.type = schema::PrimitiveType_Split;
+  std::unique_ptr<schema::SplitT> attr = std::make_unique<schema::SplitT>();
+  attr->numberSplit = 2;
+  attr->splitDim = 1;
+  split_node->primitive->value.value = attr.release();
+  split_node->name = "split";
+  meta_graph->nodes.emplace_back(std::move(split_node));
+
+  meta_graph->inputIndex = {0, 3, 4};
+  meta_graph->outputIndex = {5, 6};
+
+  auto mul_node1 = std::make_unique<schema::CNodeT>();
+  mul_node1->inputIndex = {1, 3};
+  mul_node1->outputIndex = {5};
+  mul_node1->primitive = std::make_unique<schema::PrimitiveT>();
+  mul_node1->primitive->value.type = schema::PrimitiveType_Mul;
+  std::unique_ptr<schema::MulT> mul_attr = std::make_unique<schema::MulT>();
+  mul_node1->primitive->value.value = mul_attr.release();
+  mul_node1->name = "mul1";
+  meta_graph->nodes.emplace_back(std::move(mul_node1));
+
+  auto mul_node2 = std::make_unique<schema::CNodeT>();
+  mul_node2->inputIndex = {2, 4};
+  mul_node2->outputIndex = {6};
+  mul_node2->primitive = std::make_unique<schema::PrimitiveT>();
+  mul_node2->primitive->value.type = schema::PrimitiveType_Mul;
+  std::unique_ptr<schema::MulT> mul2_attr = std::make_unique<schema::MulT>();
+  mul_node2->primitive->value.value = mul2_attr.release();
+  mul_node2->name = "mul2";
+  meta_graph->nodes.emplace_back(std::move(mul_node2));
+
+  // input 0: data1
+  auto input0 = std::make_unique<schema::TensorT>();
+  input0->nodeType = schema::NodeType::NodeType_ValueNode;
+  input0->format = schema::Format_NHWC;
+  input0->dataType = TypeId::kNumberTypeFloat32;
+  input0->dims = {1, 2, 2, 3};
+  input0->offset = -1;
+  auto input0_data = new(std::nothrow) float[2 * 2 * 3];
+  for (auto i = 0; i < 2 * 2 * 3; i++) {
+    input0_data[i] = i;
+  }
+  input0->data.resize(sizeof(float) * 2 * 2 * 3);
+  memcpy(input0->data.data(), input0_data, 2 * 2 * 3 * sizeof(float));
+  delete[] input0_data;
+  meta_graph->allTensors.emplace_back(std::move(input0));
+
+  // split output1
+  auto split_output1 = std::make_unique<schema::TensorT>();
+  split_output1->nodeType = schema::NodeType::NodeType_Parameter;
+  split_output1->format = schema::Format_NHWC;
+  split_output1->dataType = TypeId::kNumberTypeFloat32;
+  split_output1->dims = {1, 1, 2, 3};
+  split_output1->offset = -1;
+  split_output1->data.resize(sizeof(float) * 1 * 2 * 3);
+  auto split_output_data1 = new(std::nothrow) float[1 * 2 * 3];
+  memcpy(split_output1->data.data(), split_output_data1, 1 * 2 * 3 * sizeof(float));
+  delete[] split_output_data1;
+  meta_graph->allTensors.emplace_back(std::move(split_output1));
+
+  // split output2
+  auto split_output2 = std::make_unique<schema::TensorT>();
+  split_output2->nodeType = schema::NodeType::NodeType_Parameter;
+  split_output2->format = schema::Format_NHWC;
+  split_output2->dataType = TypeId::kNumberTypeFloat32;
+  split_output2->dims = {1, 1, 2, 3};
+  split_output2->offset = -1;
+  split_output2->data.resize(sizeof(float) * 1 * 2 * 3);
+  auto split_output_data2 = new(std::nothrow) float[1 * 2 * 3];
+  memcpy(split_output2->data.data(), split_output_data2, 1 * 2 * 3 * sizeof(float));
+  delete[] split_output_data2;
+  meta_graph->allTensors.emplace_back(std::move(split_output2));
+
+  // input 1: data2
+  auto input1 = std::make_unique<schema::TensorT>();
+  input1->nodeType = schema::NodeType::NodeType_ValueNode;
+  input1->format = schema::Format_NHWC;
+  input1->dataType = TypeId::kNumberTypeFloat32;
+  input1->dims = {1, 1, 2, 3};
+  input1->offset = -1;
+  input1->data.resize(sizeof(float) * 2 * 3);
+  auto input1_data = new(std::nothrow) float[2 * 3];
+  for (auto i = 0; i < 2 * 3; i++) {
+    input1_data[i] = i;
+  }
+  memcpy(input1->data.data(), input1_data, 2 * 3 * sizeof(float));
+  delete[] input1_data;
+  meta_graph->allTensors.emplace_back(std::move(input1));
+
+  // input 2: data3
+  auto input2 = std::make_unique<schema::TensorT>();
+  input2->nodeType = schema::NodeType::NodeType_ValueNode;
+  input2->format = schema::Format_NHWC;
+  input2->dataType = TypeId::kNumberTypeFloat32;
+  input2->dims = {1, 1, 2, 3};
+  input2->offset = -1;
+  input2->data.resize(sizeof(float) * 2 * 3);
+  auto input2_data = new(std::nothrow) float[2 * 3];
+  for (auto i = 0; i < 2 * 3; i++) {
+    input2_data[i] = 10;
+  }
+  memcpy(input2->data.data(), input2_data, 2 * 3 * sizeof(float));
+  delete[] input2_data;
+  meta_graph->allTensors.emplace_back(std::move(input2));
+
+  // final mul output1
+  auto mul_output = std::make_unique<schema::TensorT>();
+  mul_output->nodeType = schema::NodeType::NodeType_Parameter;
+  mul_output->format = schema::Format_NHWC;
+  mul_output->dataType = TypeId::kNumberTypeFloat32;
+  mul_output->dims = {1, 1, 2, 3};
+  meta_graph->allTensors.emplace_back(std::move(mul_output));
+
+  // final mul output2
+  auto mul_output2 = std::make_unique<schema::TensorT>();
+  mul_output2->nodeType = schema::NodeType::NodeType_Parameter;
+  mul_output2->format = schema::Format_NHWC;
+  mul_output2->dataType = TypeId::kNumberTypeFloat32;
+  mul_output2->dims = {1, 1, 2, 3};
+  meta_graph->allTensors.emplace_back(std::move(mul_output2));
+  return meta_graph;
+}
 }  // namespace
 TEST_F(ConstantFoldingFusionTest, TestADDConstantFold) {
   auto meta_graph = BuildGraph(schema::PrimitiveType_Add, new schema::AddT);
@@ -483,4 +613,19 @@ TEST_F(ConstantFoldingFusionTest, TestCastDimsConstantFold) {
   auto new_meta_graph = lite::Export(new_graph);
   ASSERT_EQ(new_meta_graph->nodes.size(), 0);
 }
+
+TEST_F(ConstantFoldingFusionTest, TestSplitConstantFold) {
+  auto meta_graph = BuildSplitGraph();
+  auto input_tensor = meta_graph->allTensors.at(0).get();
+  input_tensor->dataType = kNumberTypeFloat32;
+  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>("test", false);
+  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
+  ASSERT_NE(nullptr, new_graph);
+  auto new_meta_graph = lite::Export(new_graph);
+  ASSERT_EQ(new_meta_graph->nodes.size(), 0);
+}
 }  // namespace mindspore
diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.cc b/mindspore/lite/tools/optimizer/common/gllo_utils.cc
index 62f2613c32e..db1d538d0af 100644
--- a/mindspore/lite/tools/optimizer/common/gllo_utils.cc
+++ b/mindspore/lite/tools/optimizer/common/gllo_utils.cc
@@ -319,7 +319,7 @@ schema::PrimitiveType GetCNodeType(const BaseRef &n) {
   if (utils::isa<PrimitiveCPtr>(value)) {
     auto primitive = value->cast<PrimitiveCPtr>();
     MS_ASSERT(primitive != nullptr);
-    return (schema::PrimitiveType)primitive->Type();
+    return (schema::PrimitiveType) primitive->Type();
   } else if (utils::isa<Primitive>(value)) {
     auto primitive = value->cast<PrimitivePtr>();
     MS_ASSERT(primitive != nullptr);
@@ -392,8 +392,8 @@ size_t GetOutputTensorNum(const AnfNodePtr &node) {
 bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node) {
   auto output_node_list = GetRealNodeUsedList(graph, node);
   if (output_node_list->size() != 1) {
-      MS_LOG(DEBUG) << "fusion node has multi output nodes";
-      return true;
+    MS_LOG(DEBUG) << "fusion node has multi output nodes";
+    return true;
   }
   return false;
 }
@@ -412,5 +412,50 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(con
   std::copy(output_info_list.begin(), output_info_list.end(), std::back_inserter(*output_node_list));
   return output_node_list;
 }
+size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) {
+  MS_ASSERT(tuple_get_item != nullptr);
+  if (tuple_get_item->size() != kTupleGetItemInputSize) {
+    MS_LOG(ERROR) << "The node tuple_get_item must have 2 inputs!";
+    return -1;
+  }
+  auto output_index_value_node = tuple_get_item->input(kInputNodeOutputIndexInTupleGetItem);
+  MS_ASSERT(output_index_value_node != nullptr);
+  auto value_node = output_index_value_node->cast<ValueNodePtr>();
+  MS_ASSERT(value_node != nullptr);
+  return IntToSize(GetValue<int>(value_node->value()));
+}
+std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
+                                                                                        const AnfNodePtr &node,
+                                                                                        size_t output_index) {
+  MS_ASSERT(graph != nullptr);
+  MS_ASSERT(node != nullptr);
+  auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
+  auto manager = graph->manager();
+  MS_ASSERT(manager != nullptr);
+  auto iter = manager->node_users().find(node);
+  if (iter == manager->node_users().end()) {
+    MS_LOG(ERROR) << "node has no output in manager";
+    return output_node_list;
+  }
+  auto output_info_list = iter->second;
+  for (const auto &output_info : output_info_list) {
+    size_t used_output_index;
+    if (GetCNodeType(output_info.first) == schema::PrimitiveType_TupleGetItem) {
+      used_output_index = GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first));
+    } else if (GetCNodeType(node) == schema::PrimitiveType_TupleGetItem) {
+      used_output_index = output_index;
+    } else {
+      if (output_index != 0) {
+        MS_LOG(ERROR) << "node has no output in manager";
+        return output_node_list;
+      }
+      return output_node_list;
+    }
+    if (used_output_index == output_index) {
+      output_node_list->push_back(output_info);
+    }
+  }
+  return output_node_list;
+}
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.h b/mindspore/lite/tools/optimizer/common/gllo_utils.h
index 066882caca0..4554857ccd8 100644
--- a/mindspore/lite/tools/optimizer/common/gllo_utils.h
+++ b/mindspore/lite/tools/optimizer/common/gllo_utils.h
@@ -63,6 +63,8 @@ bool CheckIsAllInputsParam(const AnfNodePtr &node);
 size_t GetOutputTensorNum(const AnfNodePtr &node);
 
 bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node);
+
+size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item);
 }  // namespace opt
 }  // namespace mindspore
 #endif  // MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_
diff --git a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc
index 992e91dd1fa..bd7d4017716 100644
--- a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc
+++ b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc
@@ -41,7 +41,7 @@ std::vector<Tensor *> GetCNodeInputTensors(const CNodePtr &CNode) {
     auto tensorT = tmp_meta_graph->allTensors.at(input_index).get();
     auto tensor_shape = tensorT->dims;
     auto lite_tensor =
-        new (std::nothrow) Tensor(TypeId(tensorT->dataType), tensor_shape, tensorT->format, tensorT->nodeType);
+      new (std::nothrow) Tensor(TypeId(tensorT->dataType), tensor_shape, tensorT->format, tensorT->nodeType);
     if (lite_tensor == nullptr) {
       MS_LOG(ERROR) << "lite tensor is nullptr";
       return input_tensors;
@@ -106,7 +106,7 @@ kernel::LiteKernel *GetLiteKernel(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs,
   auto data_type = inputs.front()->data_type();
-  kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, (schema::PrimitiveType)primitive->Type()};
+  kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, (schema::PrimitiveType) primitive->Type()};
   lite::Context context;
   auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
   if (creator != nullptr) {
@@ -115,6 +115,44 @@ kernel::LiteKernel *GetLiteKernel(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs,
+lite::STATUS ReplaceCNode(const FuncGraphPtr &func_graph, const CNodePtr &any_node, const AnfNodePtr &input_node,
+                          std::vector<Tensor *> output_tensors, size_t replace_index) {
+  MS_ASSERT(func_graph != nullptr);
+  auto manager = func_graph->manager();
+  MS_ASSERT(manager != nullptr);
+  if (output_tensors.size() != 1) {
+    for (size_t k = 0; k < output_tensors.size(); k++) {
+      auto used_node_list = GetRealNodeUsedListByOutputIdx(func_graph, input_node, k);
+      if (used_node_list->size() != 1) {
+        MS_LOG(ERROR) << " output must tuple_getitem";
+        return lite::RET_ERROR;
+      }
+      auto tuple_node = used_node_list->at(0).first;
+      if (GetCNodeType(tuple_node) == schema::PrimitiveType_TupleGetItem) {
+        auto new_parameter = CreateNewParamter(func_graph, output_tensors.at(k));
+        if (new_parameter == nullptr) {
+          MS_LOG(ERROR) << "CreateNewParamter failed, name: " << input_node->fullname_with_scope();
+          return lite::RET_ERROR;
+        }
+        new_parameter->set_name(input_node->fullname_with_scope() + "_const_" + std::to_string(k));
+        manager->Replace(tuple_node, new_parameter);
+      } else {
+        MS_LOG(ERROR) << " multi out tensor must connect tuple-getitem: " << input_node->fullname_with_scope();
+        return lite::RET_ERROR;
+      }
+    }
+  } else {
+    auto new_parameter = CreateNewParamter(func_graph, output_tensors.front());
+    if (new_parameter == nullptr) {
+      MS_LOG(ERROR) << "CreateNewParamter failed, name: " << input_node->fullname_with_scope();
+      return lite::RET_ERROR;
+    }
+    new_parameter->set_name(input_node->fullname_with_scope());
+    any_node->set_input(replace_index, new_parameter);
+  }
+  return lite::RET_OK;
+}
 }  // namespace
 void FreeTensors(std::vector<Tensor *> *input_tensor, std::vector<Tensor *> *output_tensor) {
   if (input_tensor != nullptr) {
@@ -140,64 +178,66 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
   }
   auto any_node = node->cast<CNodePtr>();
   CheckIfCNodeIsNull(any_node);
+  bool changed = false;
   for (size_t i = 1; i < any_node->inputs().size(); i++) {
     auto input_node = any_node->input(i);
-    if (input_node->isa<CNode>() && CheckIsAllInputsParam(input_node)) {
-      auto input_cnode = input_node->cast<CNodePtr>();
-      auto input_tensors = GetCNodeInputTensors(input_cnode);
-      if (input_tensors.empty() || input_tensors.size() != input_cnode->inputs().size() - 1) {
-        FreeTensors(&input_tensors, nullptr);
-        continue;
+    if (!input_node->isa<CNode>() || !CheckIsAllInputsParam(input_node)) {
+      continue;
+    }
+    auto input_cnode = input_node->cast<CNodePtr>();
+    auto input_tensors = GetCNodeInputTensors(input_cnode);
+    if (input_tensors.empty() || input_tensors.size() != input_cnode->inputs().size() - 1) {
+      FreeTensors(&input_tensors, nullptr);
+      continue;
+    }
+    changed = true;
+    MS_LOG(INFO) << "Begin fold node:" << input_node->fullname_with_scope();
+    auto output_nums = GetOutputTensorNum(input_cnode);
+    std::vector<Tensor *> output_tensors{output_nums, new Tensor()};
+    auto lite_primitive = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(input_cnode->input(0));
+    if (lite_primitive == nullptr) {
+      MS_LOG(ERROR) << "lite_primitive is nullptr";
+      FreeTensors(&input_tensors, &output_tensors);
+      return nullptr;
+    }
+    // here, input_tensor's format need to be transposed nhwc according to fmkType,
+    // but for the time being, we only transpose the tensor with 0/1/2/3D.
+    // Others should be added in future.
+    for (size_t j = 0; j < input_tensors.size(); ++j) {
+      input_tensors[j]->SetFormat(schema::Format_NHWC);
+      if (input_tensors[j]->shape().size() == 4) {
+        MS_LOG(INFO) << "init input_tensor format to nhwc";
       }
-      MS_LOG(INFO) << "Begin fold node:" << input_node->fullname_with_scope();
-      auto output_nums = GetOutputTensorNum(input_cnode);
-      std::vector<Tensor *> output_tensors{output_nums, new Tensor()};
-      auto lite_primitive = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(input_cnode->input(0));
-      if (lite_primitive == nullptr) {
-        MS_LOG(ERROR) << "lite_primitive is nullptr";
-        FreeTensors(&input_tensors, &output_tensors);
-        return nullptr;
-      }
-      // here, input_tensor's format need to be transposed nhwc according to fmkType,
-      // but for the time being, we only transpose the tensor with 0/1/2/3D.
-      // Others should be added in future.
-      for (size_t j = 0; j < input_tensors.size(); ++j) {
-        input_tensors[j]->SetFormat(schema::Format_NHWC);
-        if (input_tensors[j]->shape().size() == 4) {
-          MS_LOG(INFO) << "init input_tensor format to nhwc";
-        }
-      }
-      lite_primitive->InferShape(input_tensors, output_tensors);
-      auto parameter = kernel::PopulateParameter(lite_primitive.get());
-      if (parameter == nullptr) {
-        MS_LOG(ERROR) << "PopulateParameter return nullptr, type: "
-                      << schema::EnumNamePrimitiveType((schema::PrimitiveType)(lite_primitive->Type()));
-        return nullptr;
-      }
-      auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, lite_primitive.get());
-      if (lite_kernel == nullptr) {
-        MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr";
-        FreeTensors(&input_tensors, &output_tensors);
-        return nullptr;
-      }
-      auto ret = lite_kernel->Run();
-      if (0 != ret) {
-        FreeTensors(&input_tensors, &output_tensors);
-        MS_LOG(ERROR) << "run kernel failed, name: " << lite_kernel->name();
-        return nullptr;
-      }
-      auto new_parameter = CreateNewParamter(func_graph, output_tensors.front());
-      if (new_parameter == nullptr) {
-        FreeTensors(&input_tensors, &output_tensors);
-        MS_LOG(ERROR) << "CreateNewParamter failed, name: " << lite_kernel->name();
-        return nullptr;
-      }
-      new_parameter->set_name(input_node->fullname_with_scope());
-      any_node->set_input(i, new_parameter);
+    }
+    lite_primitive->InferShape(input_tensors, output_tensors);
+    auto parameter = kernel::PopulateParameter(lite_primitive.get());
+    if (parameter == nullptr) {
+      MS_LOG(ERROR) << "PopulateParameter return nullptr, type: "
+                    << schema::EnumNamePrimitiveType((schema::PrimitiveType) (lite_primitive->Type()));
+      return nullptr;
+    }
+    auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, lite_primitive.get());
+    if (lite_kernel == nullptr) {
+      MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr";
+      FreeTensors(&input_tensors, &output_tensors);
+      return nullptr;
+    }
+    auto ret = lite_kernel->Run();
+    if (0 != ret) {
+      FreeTensors(&input_tensors, &output_tensors);
+      MS_LOG(ERROR) << "run kernel failed, name: " << lite_kernel->name();
+      return nullptr;
+    }
+    // replace cnode by new param
+    if (ReplaceCNode(func_graph, any_node, input_node, output_tensors, i) != lite::RET_OK) {
       FreeTensors(&input_tensors, &output_tensors);
       delete (lite_kernel);
+      MS_LOG(ERROR) << "constant_folding replace cnode failed";
+      return nullptr;
+    }
+    FreeTensors(&input_tensors, &output_tensors);
+    delete (lite_kernel);
   }
-  return any_node;
+  return changed ? any_node : nullptr;
 }
 }  // namespace mindspore::opt