forked from mindspore-Ecosystem/mindspore
add anf const fold fusion
parent
04371f6d38
commit
69c2ea82b3
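This commit ports the converter's constant folding from the old (commented-out) GraphDef passes, which are deleted below, to a single ANF-level pass, opt::ConstFoldPass: any CNode whose inputs are all constant Parameters is executed once through a CPU lite kernel at convert time and replaced by the resulting tensor. A minimal sketch of how the pass is wired in, mirroring the anf_transform.cc hunk below (`old_graph` is assumed to come from the model parser):

    auto optimizer = std::make_shared<opt::GraphOptimizer>();
    auto pm = std::make_shared<opt::PassManager>();
    pm->AddPass(std::make_shared<opt::ConstFoldPass>());  // fold subgraphs with constant inputs
    optimizer->AddPassManager(pm);
    FuncGraphPtr new_graph = optimizer->Optimize(old_graph);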

@ -228,6 +228,7 @@ if(BUILD_CONVERTER)
        ${LITE_DIR}/test/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc
        ${LITE_DIR}/test/ut/tools/optimizer/fusion/conv_bn_fusion_test.cc
        ${LITE_DIR}/test/ut/tools/optimizer/fusion/conv_scale_fusion_test.cc
        ${LITE_DIR}/test/ut/tools/optimizer/fusion/constant_folding_fusion_test.cc
        ${LITE_DIR}/tools/optimizer/common/node_pass_extends.cc
        ${LITE_DIR}/tools/optimizer/common/pass_manager_extends.cc
        ${LITE_DIR}/tools/optimizer/common/gllo_utils.cc

@ -236,6 +237,7 @@ if(BUILD_CONVERTER)
        ${LITE_DIR}/tools/optimizer/fusion/conv_transform_fusion.cc
        ${LITE_DIR}/tools/optimizer/fusion/conv_scale_fusion.cc
        ${LITE_DIR}/tools/optimizer/fusion/conv_bn_fusion.cc
        ${LITE_DIR}/tools/optimizer/fusion/constant_folding_fusion.cc
        )
endif()
### train

@ -0,0 +1,486 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <memory>
#include "schema/inner/model_generated.h"
#include "include/model.h"
#include "common/common_test.h"
#include "include/lite_session.h"
#include "include/context.h"
#include "include/errorcode.h"
#include "utils/log_adapter.h"
#include "tools/converter/model_parser.h"
#include "tools/converter/anf_transform.h"
#include "tools/optimizer/fusion/constant_folding_fusion.h"
#include "src/common/anf_exporter/anf_exporter.h"

namespace mindspore {
class ConstantFoldingFusionTest : public mindspore::CommonTest {
 public:
  ConstantFoldingFusionTest() = default;
};
using MetaGraphTptr = std::shared_ptr<schema::MetaGraphT>;
using CNodeTptr = std::unique_ptr<schema::CNodeT>;

namespace {
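// BuildGraph returns a MetaGraphT with a single two-input op whose inputs are
// constant 1x2x2x3 NHWC float32 ValueNode tensors, so the whole graph should
// fold away. op_node must point to the schema attribute struct (e.g. AddT)
// matching op_type; ownership passes to the primitive.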

MetaGraphTptr BuildGraph(schema::PrimitiveType op_type, void *op_node) {
  auto meta_graph = std::make_shared<schema::MetaGraphT>();
  meta_graph->name = "graph";
  // example node
  auto example_node = std::make_unique<schema::CNodeT>();
  example_node->inputIndex = {0, 1};
  example_node->outputIndex = {2};
  example_node->primitive = std::make_unique<schema::PrimitiveT>();
  example_node->primitive->value.type = op_type;
  example_node->primitive->value.value = op_node;
  example_node->name = "example";
  meta_graph->nodes.emplace_back(std::move(example_node));

  meta_graph->inputIndex = {0, 1};
  meta_graph->outputIndex = {2};

  // input 0: data1
  auto input0 = std::make_unique<schema::TensorT>();
  input0->nodeType = schema::NodeType::NodeType_ValueNode;
  input0->format = schema::Format_NHWC;
  input0->dataType = TypeId::kNumberTypeFloat32;
  input0->dims = {1, 2, 2, 3};
  input0->offset = -1;
  auto input0_data = new (std::nothrow) float[2 * 2 * 3];
  for (auto i = 0; i < 2 * 2 * 3; i++) {
    input0_data[i] = i;
  }
  input0->data.resize(sizeof(float) * 2 * 2 * 3);
  memcpy(input0->data.data(), input0_data, 2 * 2 * 3 * sizeof(float));
  delete[] input0_data;
  meta_graph->allTensors.emplace_back(std::move(input0));

  // input 1: data2
  auto input1 = std::make_unique<schema::TensorT>();
  input1->nodeType = schema::NodeType::NodeType_ValueNode;
  input1->format = schema::Format_NHWC;
  input1->dataType = TypeId::kNumberTypeFloat32;
  input1->dims = {1, 2, 2, 3};
  input1->offset = -1;
  input1->data.resize(sizeof(float) * 2 * 2 * 3);
  auto input1_data = new (std::nothrow) float[2 * 2 * 3];
  for (auto i = 0; i < 2 * 2 * 3; i++) {
    input1_data[i] = i;
  }
  memcpy(input1->data.data(), input1_data, 2 * 2 * 3 * sizeof(float));
  delete[] input1_data;
  meta_graph->allTensors.emplace_back(std::move(input1));

  // final output
  auto add_output = std::make_unique<schema::TensorT>();
  add_output->nodeType = schema::NodeType::NodeType_Parameter;
  add_output->format = schema::Format_NHWC;
  add_output->dataType = TypeId::kNumberTypeFloat32;
  add_output->dims = {1, 2, 2, 3};
  meta_graph->allTensors.emplace_back(std::move(add_output));
  return meta_graph;
}
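// BuildGraphForOneInput mirrors BuildGraph for unary ops (Shape, Rsqrt,
// Reshape, Range, ExpandDims, Cast, StridedSlice): one constant input, one
// output.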

MetaGraphTptr BuildGraphForOneInput(schema::PrimitiveType op_type, void *op_node) {
  auto meta_graph = std::make_shared<schema::MetaGraphT>();
  meta_graph->name = "graph";
  // example node
  auto example_node = std::make_unique<schema::CNodeT>();
  example_node->inputIndex = {0};
  example_node->outputIndex = {1};
  example_node->primitive = std::make_unique<schema::PrimitiveT>();
  example_node->primitive->value.type = op_type;
  example_node->primitive->value.value = op_node;
  example_node->name = "example";
  meta_graph->nodes.emplace_back(std::move(example_node));

  meta_graph->inputIndex = {0};
  meta_graph->outputIndex = {1};

  // input 0: data1
  auto input0 = std::make_unique<schema::TensorT>();
  input0->nodeType = schema::NodeType::NodeType_ValueNode;
  input0->format = schema::Format_NHWC;
  input0->dataType = TypeId::kNumberTypeFloat32;
  input0->dims = {1, 2, 2, 3};
  input0->offset = -1;
  auto input0_data = new (std::nothrow) float[2 * 2 * 3];
  for (auto i = 0; i < 2 * 2 * 3; i++) {
    input0_data[i] = i + 1;
  }
  input0->data.resize(sizeof(float) * 2 * 2 * 3);
  memcpy(input0->data.data(), input0_data, 2 * 2 * 3 * sizeof(float));
  delete[] input0_data;
  meta_graph->allTensors.emplace_back(std::move(input0));

  // final output
  auto add_output = std::make_unique<schema::TensorT>();
  add_output->nodeType = schema::NodeType::NodeType_Parameter;
  add_output->format = schema::Format_NHWC;
  add_output->dataType = TypeId::kNumberTypeFloat32;
  add_output->dims = {1, 2, 2, 3};
  meta_graph->allTensors.emplace_back(std::move(add_output));

  return meta_graph;
}
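// BuildMixGraph chains add -> mul, so folding has to propagate through an
// intermediate tensor: tensor 2 is both the add output and a mul input.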

MetaGraphTptr BuildMixGraph() {
  auto meta_graph = std::make_shared<schema::MetaGraphT>();
  meta_graph->name = "graph";
  // add node
  auto add_node = std::make_unique<schema::CNodeT>();
  add_node->inputIndex = {0, 1};
  add_node->outputIndex = {2};
  add_node->primitive = std::make_unique<schema::PrimitiveT>();
  add_node->primitive->value.type = schema::PrimitiveType_Add;
  add_node->primitive->value.value = new schema::AddT;
  add_node->name = "add";
  meta_graph->nodes.emplace_back(std::move(add_node));

  meta_graph->inputIndex = {0, 1, 2};
  meta_graph->outputIndex = {4};

  auto mul_node = std::make_unique<schema::CNodeT>();
  mul_node->inputIndex = {2, 3};
  mul_node->outputIndex = {4};
  mul_node->primitive = std::make_unique<schema::PrimitiveT>();
  mul_node->primitive->value.type = schema::PrimitiveType_Mul;
  mul_node->primitive->value.value = new schema::MulT;
  mul_node->name = "mul";
  meta_graph->nodes.emplace_back(std::move(mul_node));

  // input 0: data1
  auto input0 = std::make_unique<schema::TensorT>();
  input0->nodeType = schema::NodeType::NodeType_ValueNode;
  input0->format = schema::Format_NHWC;
  input0->dataType = TypeId::kNumberTypeFloat32;
  input0->dims = {1, 2, 2, 3};
  input0->offset = -1;
  auto input0_data = new (std::nothrow) float[2 * 2 * 3];
  for (auto i = 0; i < 2 * 2 * 3; i++) {
    input0_data[i] = i;
  }
  input0->data.resize(sizeof(float) * 2 * 2 * 3);
  memcpy(input0->data.data(), input0_data, 2 * 2 * 3 * sizeof(float));
  delete[] input0_data;
  meta_graph->allTensors.emplace_back(std::move(input0));

  // input 1: data2
  auto input1 = std::make_unique<schema::TensorT>();
  input1->nodeType = schema::NodeType::NodeType_ValueNode;
  input1->format = schema::Format_NHWC;
  input1->dataType = TypeId::kNumberTypeFloat32;
  input1->dims = {1, 2, 2, 3};
  input1->offset = -1;
  input1->data.resize(sizeof(float) * 2 * 2 * 3);
  auto input1_data = new (std::nothrow) float[2 * 2 * 3];
  for (auto i = 0; i < 2 * 2 * 3; i++) {
    input1_data[i] = i;
  }
  memcpy(input1->data.data(), input1_data, 2 * 2 * 3 * sizeof(float));
  delete[] input1_data;
  meta_graph->allTensors.emplace_back(std::move(input1));

  // add output
  auto add_output = std::make_unique<schema::TensorT>();
  add_output->nodeType = schema::NodeType::NodeType_Parameter;
  add_output->format = schema::Format_NHWC;
  add_output->dataType = TypeId::kNumberTypeFloat32;
  add_output->dims = {1, 2, 2, 3};
  add_output->offset = -1;
  add_output->data.resize(sizeof(float) * 2 * 2 * 3);
  // value-initialize so the placeholder buffer is zero-filled, not uninitialized
  auto add_output_data = new (std::nothrow) float[2 * 2 * 3]();
  memcpy(add_output->data.data(), add_output_data, 2 * 2 * 3 * sizeof(float));
  delete[] add_output_data;
  meta_graph->allTensors.emplace_back(std::move(add_output));

  // input 2: data3
  auto input2 = std::make_unique<schema::TensorT>();
  input2->nodeType = schema::NodeType::NodeType_ValueNode;
  input2->format = schema::Format_NHWC;
  input2->dataType = TypeId::kNumberTypeFloat32;
  input2->dims = {1, 2, 2, 3};
  input2->offset = -1;
  input2->data.resize(sizeof(float) * 2 * 2 * 3);
  auto input2_data = new (std::nothrow) float[2 * 2 * 3];
  for (auto i = 0; i < 2 * 2 * 3; i++) {
    input2_data[i] = 10;
  }
  memcpy(input2->data.data(), input2_data, 2 * 2 * 3 * sizeof(float));
  delete[] input2_data;
  meta_graph->allTensors.emplace_back(std::move(input2));

  // final mul output
  auto mul_output = std::make_unique<schema::TensorT>();
  mul_output->nodeType = schema::NodeType::NodeType_Parameter;
  mul_output->format = schema::Format_NHWC;
  mul_output->dataType = TypeId::kNumberTypeFloat32;
  mul_output->dims = {1, 2, 2, 3};
  meta_graph->allTensors.emplace_back(std::move(mul_output));
  return meta_graph;
}
} // namespace
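// Every test below follows the same pattern: build a MetaGraphT whose inputs
// are constants, import it with Fb2Anf, run ConstFoldPass via the graph
// optimizer, export the result, and assert that all nodes were folded away.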
TEST_F(ConstantFoldingFusionTest, TestADDConstantFold) {
  auto meta_graph = BuildGraph(schema::PrimitiveType_Add, new schema::AddT);
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_EQ(new_meta_graph->nodes.size(), 0);
}

TEST_F(ConstantFoldingFusionTest, TestMixedConstantFold) {
  auto meta_graph = BuildMixGraph();
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_EQ(new_meta_graph->nodes.size(), 0);
}

TEST_F(ConstantFoldingFusionTest, TestSubConstantFold) {
  auto meta_graph = BuildGraph(schema::PrimitiveType_Sub, new schema::SubT);
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_EQ(new_meta_graph->nodes.size(), 0);
}

TEST_F(ConstantFoldingFusionTest, TestMulConstantFold) {
  auto meta_graph = BuildGraph(schema::PrimitiveType_Mul, new schema::MulT);
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_EQ(new_meta_graph->nodes.size(), 0);
}

TEST_F(ConstantFoldingFusionTest, TestTransposeConstantFold) {
  auto transposeT = new schema::TransposeT;
  transposeT->perm = {3, 0, 1, 2};
  auto meta_graph = BuildGraph(schema::PrimitiveType_Transpose, transposeT);
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_EQ(new_meta_graph->nodes.size(), 0);
}

TEST_F(ConstantFoldingFusionTest, TestTileConstantFold) {
  auto tileT = new schema::TileT;
  tileT->multiples = {1, 2, 2, 2};
  auto meta_graph = BuildGraph(schema::PrimitiveType_Tile, tileT);
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_EQ(new_meta_graph->nodes.size(), 0);
}

TEST_F(ConstantFoldingFusionTest, TestStridedSliceConstantFold) {
  auto stridedSliceT = new schema::StridedSliceT;
  stridedSliceT->begin = {1};
  stridedSliceT->end = {3};
  stridedSliceT->stride = {1};
  auto meta_graph = BuildGraphForOneInput(schema::PrimitiveType_StridedSlice, stridedSliceT);
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_EQ(new_meta_graph->nodes.size(), 0);
}

TEST_F(ConstantFoldingFusionTest, TestStackConstantFold) {
  auto stackT = new schema::StackT;
  stackT->axis = 1;
  auto meta_graph = BuildGraph(schema::PrimitiveType_Stack, stackT);
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_EQ(new_meta_graph->nodes.size(), 0);
}

TEST_F(ConstantFoldingFusionTest, TestSliceConstantFold) {
  auto sliceT = new schema::SliceT;
  auto meta_graph = BuildGraph(schema::PrimitiveType_Slice, sliceT);
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_EQ(new_meta_graph->nodes.size(), 0);
}

TEST_F(ConstantFoldingFusionTest, TestShapeConstantFold) {
  auto shapeT = new schema::ShapeT;
  auto meta_graph = BuildGraphForOneInput(schema::PrimitiveType_Shape, shapeT);
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_EQ(new_meta_graph->nodes.size(), 0);
}

TEST_F(ConstantFoldingFusionTest, TestRsqrtConstantFold) {
  auto rsqrtT = new schema::RsqrtT;
  auto meta_graph = BuildGraphForOneInput(schema::PrimitiveType_Rsqrt, rsqrtT);
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_EQ(new_meta_graph->nodes.size(), 0);
}

TEST_F(ConstantFoldingFusionTest, TestReshapeConstantFold) {
  auto reshapeT = new schema::ReshapeT;
  reshapeT->shape = {2, 6};
  auto meta_graph = BuildGraphForOneInput(schema::PrimitiveType_Reshape, reshapeT);
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_EQ(new_meta_graph->nodes.size(), 0);
}

TEST_F(ConstantFoldingFusionTest, TestRangeConstantFold) {
  auto rangeT = new schema::RangeT;
  rangeT->limit = 10;
  rangeT->start = 1;
  rangeT->delta = 1;
  auto meta_graph = BuildGraphForOneInput(schema::PrimitiveType_Range, rangeT);
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_EQ(new_meta_graph->nodes.size(), 0);
}

TEST_F(ConstantFoldingFusionTest, TestMatmulConstantFold) {
  auto matmulT = new schema::MatMulT;
  auto meta_graph = BuildGraph(schema::PrimitiveType_MatMul, matmulT);
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_EQ(new_meta_graph->nodes.size(), 0);
}

TEST_F(ConstantFoldingFusionTest, TestExpandDimsConstantFold) {
  auto expandDimsT = new schema::ExpandDimsT;
  auto meta_graph = BuildGraphForOneInput(schema::PrimitiveType_ExpandDims, expandDimsT);
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_EQ(new_meta_graph->nodes.size(), 0);
}

TEST_F(ConstantFoldingFusionTest, TestConcatDimsConstantFold) {
  auto concatT = new schema::ConcatT;
  auto meta_graph = BuildGraph(schema::PrimitiveType_Concat, concatT);
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_EQ(new_meta_graph->nodes.size(), 0);
}

TEST_F(ConstantFoldingFusionTest, TestCastDimsConstantFold) {
  auto castT = new schema::CastT;
  castT->srcT = kNumberTypeUInt8;
  castT->dstT = kNumberTypeFloat32;
  auto meta_graph = BuildGraphForOneInput(schema::PrimitiveType_Cast, castT);
  auto input_tensor = meta_graph->allTensors.at(0).get();
  input_tensor->dataType = kNumberTypeUInt8;
  auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(func_graph);
  ASSERT_NE(nullptr, new_graph);
  auto new_meta_graph = lite::Export(new_graph);
  ASSERT_EQ(new_meta_graph->nodes.size(), 0);
}
} // namespace mindspore

@ -81,6 +81,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
        ../optimizer/fusion/conv_transform_fusion.cc
        ../optimizer/fusion/conv_scale_fusion.cc
        ../optimizer/fusion/conv_bn_fusion.cc
        ../optimizer/fusion/constant_folding_fusion.cc
        )

add_subdirectory(parser/caffe)

@ -18,10 +18,11 @@
#include <memory>
#include <string>
#include "utils/log_adapter.h"
-#include "mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.h"
-#include "mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.h"
-#include "mindspore/lite/tools/optimizer/fusion/conv_scale_fusion.h"
-#include "mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.h"
+#include "tools/optimizer/fusion/conv_biasadd_fusion.h"
+#include "tools/optimizer/fusion/conv_activation_fusion.h"
+#include "tools/optimizer/fusion/conv_scale_fusion.h"
+#include "tools/optimizer/fusion/conv_bn_fusion.h"
+#include "tools/optimizer/fusion/constant_folding_fusion.h"

using std::string;
namespace mindspore {

@ -43,6 +44,7 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph) {
                                                          schema::ActivationType_RELU));
  pm->AddPass(std::make_shared<opt::ConvActivationFusion>(true, "conv_relu6", schema::PrimitiveType_Activation,
                                                          schema::ActivationType_RELU6));
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(old_graph);
  return new_graph;

@ -95,55 +95,6 @@ void GraphDefTransform::CreateQuantizer(const converter::Flags *flags) {

int GraphDefTransform::Transform(const converter::Flags &ctx) {
  STATUS status;
  // // constant folding
  // {
  //   Optimizer topologicalSortOptimizer;
  //   topologicalSortOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
  //   status = topologicalSortOptimizer.Run(graphDefT);
  //   if (status != RET_OK) {
  //     MS_LOG(ERROR) << "Run topologicalSortOptimizer graphPasses Failed";
  //     return status;
  //   }
  //   Optimizer constFoldOptimizer;
  //   constFoldOptimizer.AddPass(new (std::nothrow) AddConstFoldPass());
  //   constFoldOptimizer.AddPass(new (std::nothrow) CastConstFoldPass());
  //   constFoldOptimizer.AddPass(new (std::nothrow) ConcatV2ConstFoldPass());
  //   constFoldOptimizer.AddPass(new (std::nothrow) ExpandDimsConstFoldPass());
  //   constFoldOptimizer.AddPass(new (std::nothrow) MulConstFoldPass());
  //   constFoldOptimizer.AddPass(new (std::nothrow) RangeConstFoldPass());
  //   constFoldOptimizer.AddPass(new (std::nothrow) ReshapeConstFoldPass());
  //   constFoldOptimizer.AddPass(new (std::nothrow) RsqrtConstFoldPass());
  //   constFoldOptimizer.AddPass(new (std::nothrow) ShapeConstFoldPass());
  //   constFoldOptimizer.AddPass(new (std::nothrow) SliceConstFoldPass());
  //   constFoldOptimizer.AddPass(new (std::nothrow) StackConstFoldPass());
  //   constFoldOptimizer.AddPass(new (std::nothrow) StridedSliceConstFoldPass());
  //   constFoldOptimizer.AddPass(new (std::nothrow) SubConstFoldPass());
  //   constFoldOptimizer.AddPass(new (std::nothrow) TileConstFoldPass());
  //   constFoldOptimizer.AddPass(new (std::nothrow) TransposeConstFoldPass());
  //   constFoldOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
  //   status = constFoldOptimizer.Run(graphDefT);
  //   if (status != RET_OK && status != RET_NO_CHANGE) {
  //     MS_LOG(ERROR) << "Run constFoldOptimizer graphPasses Failed";
  //     return status;
  //   }
  // }

  // fusion
  // {
  //   Optimizer fusionOptimizer;
  //   fusionOptimizer.AddPass(new (std::nothrow) ConvBiasAddFusionPass());
  //   fusionOptimizer.AddPass(new (std::nothrow) ConvBNFusionPass());
  //   fusionOptimizer.AddPass(new (std::nothrow) ConvScaleFusionPass());
  //   fusionOptimizer.AddPass(new (std::nothrow) ConvReluFusionPass());
  //   fusionOptimizer.AddPass(new (std::nothrow) ConvRelu6FusionPass());
  //   fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
  //   status = fusionOptimizer.Run(graphDefT);
  //   if (status != RET_OK && status != RET_NO_CHANGE) {
  //     MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed";
  //     return status;
  //   }
  // }

  // weight format trans
  if (ctx.formatTrans) {
    Optimizer weightFormatOptimizer;

@ -17,6 +17,7 @@
#include <vector>
#include "src/ir/primitive_t_value.h"
#include "frontend/operator/ops.h"
#include "backend/optimizer/common/helper.h"

namespace mindspore {
namespace opt {

@ -314,6 +315,11 @@ schema::PrimitiveType GetCNodeType(const BaseRef &n) {
    auto primitive = value->cast<PrimitiveTValuePtr>();
    MS_ASSERT(primitive != nullptr);
    return primitive->GetPrimitiveT()->value.type;
  } else if (utils::isa<Primitive>(value)) {
    auto primitive = value->cast<PrimitivePtr>();
    MS_ASSERT(primitive != nullptr);
    MS_LOG(INFO) << "anf primitive node type:" << primitive->name();
    return schema::PrimitiveType_NONE;
  }
  return schema::PrimitiveType_NONE;
}

@ -329,5 +335,37 @@ bool IsConvNode(const BaseRef &n) {
  }
  return false;
}
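// CheckIsAllInputsParam returns true for a CNode whose every real input
// (indices 1..n, skipping the primitive at index 0) is a Parameter; such
// nodes are the candidates ConstFoldPass can evaluate at convert time.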

bool CheckIsAllInputsParam(const AnfNodePtr &node) {
  if (utils::isa<CNode>(node)) {
    auto cnode = node->cast<CNodePtr>();
    for (size_t i = 1; i < cnode->inputs().size(); i++) {
      if (!utils::isa<Parameter>(cnode->input(i))) {
        return false;
      }
    }
    return true;
  }
  return false;
}
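// GetOutputTensorNum derives the number of output tensors from the node's
// inferred type: one per element for a Tuple, zero for TypeNone, otherwise
// one (including when no type has been inferred yet).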

size_t GetOutputTensorNum(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto type = node->Type();
  if (type == nullptr) {
    return 1;
  }
  if (type->isa<Tuple>()) {
    auto tuple_type = type->cast<TuplePtr>();
    MS_EXCEPTION_IF_NULL(tuple_type);
    return tuple_type->size();
  } else if (type->isa<TensorType>() || type->isa<Number>()) {
    return 1;
  } else if (type->isa<TypeNone>()) {
    return 0;
  } else {
    return 1;
  }
}
} // namespace opt
} // namespace mindspore

@ -29,13 +29,6 @@
using PrimitiveTValuePtr = std::shared_ptr<mindspore::lite::PrimitiveTValue>;
namespace mindspore {
namespace opt {
bool AnfEqual(const BaseRef &a, const BaseRef &b);

bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b);

AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
                      bool multigraph = false);

bool IsRealCNodeKernel(const AnfNodePtr &node);

bool IsGraphKernel(const AnfNodePtr &node);

@ -62,6 +55,10 @@ schema::PrimitiveType GetCNodeType(const BaseRef &node);
bool IsParamNode(const BaseRef &n);

bool IsConvNode(const BaseRef &n);

bool CheckIsAllInputsParam(const AnfNodePtr &node);

size_t GetOutputTensorNum(const AnfNodePtr &node);
} // namespace opt
} // namespace mindspore
#endif  // MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_

@ -1,90 +0,0 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_OPTIMIZER_H_
#define MINDSPORE_LITE_SRC_PASS_COMMON_OPTIMIZER_H_

#include <memory>
#include <string>
#include <vector>
#include <unordered_map>

#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ir/graph_utils.h"
#include "src/common/utils.h"

#include "backend/optimizer/common/pass_manager.h"
#include "backend/optimizer/common/pattern_engine.h"
#include "tools/optimizer/common/gllo_utils.h"

namespace mindspore {
namespace opt {
using PatternListType = std::initializer_list<BaseRef>;

class PatternProcessPass : public NodePass {
 public:
  explicit PatternProcessPass(const std::string &name = "", bool multigraph = true);
  ~PatternProcessPass() override = default;
  virtual const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const = 0;
  virtual const BaseRef DefinePattern() const;
  AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;

 private:
  void Build();

  AnfNodePtr pattern_ = nullptr;
  bool multigraph_ = true;
  PatternEngine pattern_engine_;
  PrimitiveVarMapPtr primitive_vars_;
};

class MultipleOutputPatternProcessPass : public PatternProcessPass {
 public:
  explicit MultipleOutputPatternProcessPass(const std::string &name = "", bool multigraph = true)
      : PatternProcessPass(name, multigraph),
        child_pattern_engine_(PatternEngine(std::make_shared<DefaultVisitor>(),
                                            std::function<bool(const BaseRef &, const BaseRef &)>(AnfEqual),
                                            std::function<bool(const BaseRef &, const BaseRef &)>(CNodeTypeEqual))),
        child_primitive_vars_(std::make_shared<PrimitiveVarMap>()) {}
  ~MultipleOutputPatternProcessPass() override = default;
  virtual BaseRef DefineAnotherPattern() const = 0;
  // check whether two patterns share the same nodes or not
  virtual bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const = 0;

 protected:
  bool MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const;
  PatternEngine child_pattern_engine_;
  PrimitiveVarMapPtr child_primitive_vars_;
};

class GraphOptimizer {
 public:
  explicit GraphOptimizer(const std::string &name = "graph_optimizer") : name_(name) {}
  virtual ~GraphOptimizer() = default;

  void AddPassManager(const PassManagerPtr &pass_manager);
  FuncGraphPtr Optimize(const FuncGraphPtr &func_graph, bool run_only_once = true);

 private:
  const std::string name_ = "graph_optimizer";
  std::vector<PassManagerPtr> pass_managers_{};
  bool run_only_once_ = true;
};
} // namespace opt
} // namespace mindspore

#endif  // MINDSPORE_LITE_SRC_PASS_COMMON_OPTIMIZER_H_

@ -0,0 +1,169 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/optimizer/fusion/constant_folding_fusion.h"
#include <memory>
#include <set>
#include <vector>
#include <algorithm>
#include "schema/inner/model_generated.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "src/kernel_factory.h"
#include "src/common/anf_exporter/anf_exporter.h"
#include "src/scheduler.h"
#include "include/context.h"
#include "src/lite_session.h"
#include "src/ir/primitive_t_value.h"
#include "src/populate_parameter.h"

using mindspore::lite::KernelFactory;
using mindspore::lite::tensor::Tensor;
using mindspore::lite::PrimitiveTValue;
namespace mindspore::opt {
namespace {
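// GetCNodeInputTensors exports the CNode's inputs through AnfExporter and
// copies each constant tensor into a runtime lite Tensor. If any input has no
// data (i.e. it is a graph input), an incomplete vector is returned and the
// caller skips folding. The returned tensors own heap copies of the data.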
const std::vector<Tensor *> GetCNodeInputTensors(const CNodePtr &CNode) {
  MS_ASSERT(CNode != nullptr);
  auto tmp_meta_graph = std::make_unique<schema::MetaGraphT>();
  auto tmp_fb_node = std::make_unique<schema::CNodeT>();
  lite::AnfExporter anfExporter;
  anfExporter.SetOpInputNode(CNode, tmp_meta_graph.get(), tmp_fb_node.get());
  std::vector<Tensor *> input_tensors;
  for (auto input_index : tmp_fb_node->inputIndex) {
    auto tensorT = tmp_meta_graph->allTensors.at(input_index).get();
    auto tensor_shape = tensorT->dims;
    auto lite_tensor =
        new (std::nothrow) Tensor(TypeId(tensorT->dataType), tensor_shape, tensorT->format, tensorT->nodeType);
    auto lite_tensor_size = tensorT->data.size() * sizeof(uint8_t);
    // when tensorT is a graph input, it carries no data and nothing can be folded
    if (lite_tensor_size == 0) {
      return input_tensors;
    }
    auto tensor_data = new (std::nothrow) char[lite_tensor_size / sizeof(char)];
    auto ret = memcpy_s(tensor_data, lite_tensor_size, tensorT->data.data(), lite_tensor_size);
    if (ret != EOK) {
      MS_LOG(EXCEPTION) << "memcpy error: " << ret;
    }
    lite_tensor->SetData(tensor_data);
    input_tensors.emplace_back(lite_tensor);
  }
  return input_tensors;
}
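// PackPrimitiveT serializes the CNode's PrimitiveT into a flatbuffer and
// reads it back as a schema::Primitive for lite::Primitive::CreatePrimitive.
// Caution: the returned pointer aliases the local FlatBufferBuilder's buffer,
// which is released when this function returns, so the result must be
// consumed immediately.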
schema::Primitive *PackPrimitiveT(const CNodePtr &cnode) {
  auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0));
  if (primitiveT_value == nullptr) {
    MS_LOG(ERROR) << "PrimitiveT_value is nullptr";
    return nullptr;
  }

  auto *lite_primitive = primitiveT_value->GetPrimitiveT();
  if (lite_primitive == nullptr) {
    MS_LOG(ERROR) << "Primitive in primitiveT_value is nullptr";
    return nullptr;
  }

  flatbuffers::FlatBufferBuilder builder(1024);
  auto offset = schema::Primitive::Pack(builder, lite_primitive);
  builder.Finish(offset);
  auto buf = builder.GetBufferPointer();
  auto primitive = flatbuffers::GetRoot<schema::Primitive>(buf);
  return const_cast<schema::Primitive *>(primitive);
}
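// CreateNewParameter materializes a folded output tensor as a new graph
// Parameter: the abstract carries shape/type, and a ParamValueLite holds a
// heap copy of the tensor data.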
const ParameterPtr CreateNewParameter(const FuncGraphPtr &func_graph, Tensor *tensor) {
  auto parameter = func_graph->add_parameter();
  std::vector<int> shape;
  std::copy(tensor->shape().begin(), tensor->shape().end(), std::back_inserter(shape));
  auto type_id = static_cast<TypeId>(tensor->data_type());
  auto type_ptr = TypeIdToType(type_id);
  auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape);
  parameter->set_abstract(abstract_tensor);

  ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
  MS_ASSERT(param_value != nullptr);
  param_value->set_tensor_shape(shape);
  param_value->set_tensor_type(type_id);
  if (tensor->Data() != nullptr) {
    auto size = tensor->ElementsNum();
    auto tensor_data = new (std::nothrow) float[size];
    auto ret = memcpy_s(tensor_data, size * sizeof(float), tensor->Data(), size * sizeof(float));
    if (ret != EOK) {
      MS_LOG(EXCEPTION) << "memcpy error: " << ret;
    }
    param_value->set_tensor_addr(tensor_data);
    param_value->set_tensor_size(size * sizeof(float) / sizeof(uint8_t));
  }
  parameter->set_default_param(param_value);
  return parameter;
}
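// GetLiteKernel reuses the runtime kernel registry to instantiate the CPU
// kernel that can execute the folded op at convert time.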
kernel::LiteKernel *GetLiteKernel(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs,
                                  lite::Primitive *primitive) {
  MS_ASSERT(nullptr != primitive);
  auto data_type = inputs.front()->data_type();
  kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, primitive->Type()};
  lite::Context context;
  auto parameter = kernel::PopulateParameter(primitive);
  if (parameter == nullptr) {
    MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(primitive->Type());
    return nullptr;
  }
  auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
  if (creator != nullptr) {
    auto lite_kernel = creator(inputs, outputs, parameter, &context, desc, primitive);
    return lite_kernel;
  }
  return nullptr;
}
} // namespace
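// Process walks the inputs of the visited node; any input that is itself a
// CNode fed only by Parameters is executed once through a CPU lite kernel and
// replaced with a Parameter holding the computed tensor.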

const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                        const EquivPtr &) const {
  CheckIfFuncGraphIsNull(func_graph);
  CheckIfAnfNodeIsNull(node);
  if (!node->isa<CNode>()) {
    return node;
  }
  auto any_node = node->cast<CNodePtr>();
  CheckIfCNodeIsNull(any_node);
  for (size_t i = 1; i < any_node->inputs().size(); i++) {
    auto input_node = any_node->input(i);
    if (input_node->isa<CNode>() && CheckIsAllInputsParam(input_node)) {
      auto input_cnode = input_node->cast<CNodePtr>();
      auto input_tensors = GetCNodeInputTensors(input_cnode);
      if (input_tensors.empty() || input_tensors.size() != input_cnode->inputs().size() - 1) {
        return any_node;
      }
      MS_LOG(INFO) << "Begin fold node:" << input_node->fullname_with_scope();
      auto output_nums = GetOutputTensorNum(input_cnode);
      // one distinct Tensor per inferred output
      std::vector<Tensor *> output_tensors;
      for (size_t j = 0; j < output_nums; j++) {
        output_tensors.push_back(new (std::nothrow) Tensor());
      }
      auto schema_primitive = PackPrimitiveT(input_cnode);
      auto lite_primitive = lite::Primitive::CreatePrimitive(schema_primitive);
      if (lite_primitive == nullptr) {
        MS_LOG(ERROR) << "constant_folding create primitive failed";
        return any_node;
      }
      lite_primitive->InferShape(input_tensors, output_tensors);
      auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, lite_primitive);
      if (lite_kernel == nullptr) {
        MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr";
        return any_node;
      }
      auto ret = lite_kernel->Run();
      if (0 != ret) {
        MS_LOG(EXCEPTION) << "run kernel failed, name: " << lite_kernel->name();
      }
      auto new_parameter = CreateNewParameter(func_graph, output_tensors.front());
      any_node->set_input(i, new_parameter);
    }
  }
  return any_node;
}
} // namespace mindspore::opt

@ -13,29 +13,21 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
-#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_PASS_H_
-#define MINDSPORE_LITE_SRC_PASS_COMMON_PASS_H_
-#include <memory>
-#include <string>
-
-#include "ir/anf.h"
-#include "frontend/operator/ops.h"
+#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONSTANT_FOLDING_FUSION_H_
+#define MINDSPORE_LITE_SRC_PASS_FUSION_CONSTANT_FOLDING_FUSION_H_
+
+#include "backend/optimizer/common/optimizer.h"

namespace mindspore {
namespace opt {
-// @brief ANF Graph level optimization base pass
-class Pass {
+class ConstFoldPass : public PatternProcessPass {
 public:
-  explicit Pass(const std::string &name = "pass") : name_(name) {}
-  virtual ~Pass() = default;
-  virtual bool Run(const FuncGraphPtr &func_graph) = 0;
-  virtual std::string name() const { return name_; }
-
- private:
-  const std::string name_;
+  explicit ConstFoldPass(bool multigraph = true) : PatternProcessPass("constfold_pass", multigraph) {}
+  ~ConstFoldPass() override = default;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
-using PassPtr = std::shared_ptr<Pass>;
} // namespace opt
} // namespace mindspore
+#endif  // MINDSPORE_LITE_SRC_PASS_FUSION_CONSTANT_FOLDING_FUSION_H_
-
-#endif  // MINDSPORE_LITE_SRC_PASS_COMMON_PASS_H_

@ -18,7 +18,8 @@
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_ACTIVATION_FUSION_H_

#include <string>
-#include "tools/optimizer/common/optimizer.h"
+#include "backend/optimizer/common/optimizer.h"
#include "schema/inner/model_generated.h"

namespace mindspore {
namespace opt {

@ -17,7 +17,7 @@
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BIASADD_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BIASADD_FUSION_H_

-#include "tools/optimizer/common/optimizer.h"
+#include "backend/optimizer/common/optimizer.h"

namespace mindspore {
namespace opt {

@ -18,7 +18,7 @@
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_TRANSFORM_FUSION_H_

#include <string>
-#include "tools/optimizer/common/optimizer.h"
+#include "backend/optimizer/common/optimizer.h"

namespace mindspore::opt {
class ConvTransformFusion : public PatternProcessPass {