!26610 add fusion_inout_test

Merge pull request !26610 from hangq/earth2
This commit is contained in:
i-robot 2021-11-26 03:42:23 +00:00 committed by Gitee
commit 8178fd5b7d
9 changed files with 458 additions and 1 deletions

View File

@ -29,8 +29,8 @@ void Conv2DFusion::Init(int64_t in_channel, int64_t out_channel, const std::vect
this->set_out_channel(out_channel);
this->set_kernel_size(kernel_size);
this->set_mode(mode);
this->set_pad_mode(pad_mode);
this->set_pad(pad);
this->set_pad_mode(pad_mode);
this->set_stride(stride);
this->set_dilation(dilation);
this->set_group(group);

View File

@ -112,6 +112,10 @@ if(MSLITE_ENABLE_CONVERTER)
${TEST_DIR}/st/graph_test.cc
${TEST_DIR}/st/sub_graph_test.cc
${TEST_DIR}/ut/src/dynamic_library_loader_test.cc
${TEST_DIR}/ut/tools/optimizer/fusion/fusion_inout_test/fusion_inout_test.cc
${TEST_DIR}/ut/tools/optimizer/fusion/fusion_inout_test/conv_fusion_inout_test.cc
${TEST_DIR}/ut/tools/optimizer/fusion/fusion_inout_test/conv_act_fusion_inout_test.cc
${TEST_DIR}/ut/tools/optimizer/fusion/fusion_inout_test/conv_bias_fusion_inout_test.cc
${TEST_DIR}/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc
${TEST_DIR}/ut/tools/optimizer/fusion/conv_bn_fusion_test.cc
${TEST_DIR}/ut/tools/optimizer/fusion/conv_scale_fusion_test.cc

View File

@ -32,6 +32,8 @@ echo 'run common ut tests'
# test cases of Converter
## ./lite-test --gtest_filter="TestTfliteParser*"
./lite-test --gtest_filter="ConvActFusionInoutTest*"
./lite-test --gtest_filter="ConvBiasFusionInoutTest*"
# test cases of framework

View File

@ -0,0 +1,72 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "tools/optimizer/fusion/conv_activation_fusion.h"
#include "test/ut/tools/optimizer/fusion/fusion_inout_test/conv_fusion_inout_test.h"
#include "backend/kernel_compiler/cpu/nnacl/op_base.h"
#include "ops/fusion/activation.h"
namespace mindspore {
class ConvActFusionInoutTest : public ConvFusionInoutTest {
public:
ConvActFusionInoutTest() = default;
protected:
void InitPass() override { this->pass_ = std::make_shared<opt::ConvActivationFusion>(); }
void InitGraph() override {
this->graph_ = std::make_shared<FuncGraph>();
MS_CHECK_TRUE_MSG(graph_ != nullptr, , "Create FuncGraph failed");
auto input = AddParameter(graph_, 0, {1, ih_, iw_, ic_}, kNumberTypeFloat32, "graph_input");
if (input == nullptr) {
this->graph_ = nullptr;
return;
}
auto conv = AddConv(graph_, input, "conv");
if (conv == nullptr) {
this->graph_ = nullptr;
return;
}
auto act = AddAct(graph_, conv, "conv_act");
if (act == nullptr) {
this->graph_ = nullptr;
return;
}
auto ret = AddReturn(graph_, {act});
if (ret == nullptr) {
this->graph_ = nullptr;
return;
}
}
private:
static CNodePtr AddAct(const FuncGraphPtr &graph, const AnfNodePtr &input, const std::string &name) {
auto prim = std::make_unique<ops::Activation>();
MS_CHECK_TRUE_MSG(prim != nullptr, nullptr, "create Act primitivec failed");
prim->Init();
prim->set_activation_type(ActivationType::RELU);
auto act_primitive = NewValueNode(std::shared_ptr<ops::PrimitiveC>(prim.release()));
MS_CHECK_TRUE_RET(act_primitive != nullptr, nullptr);
auto act = graph->NewCNode({act_primitive, input});
MS_CHECK_TRUE_MSG(act != nullptr, nullptr, "create Act failed");
act->set_fullname_with_scope(name);
return act;
}
};
TEST_F(ConvActFusionInoutTest, test) { ASSERT_EQ(DoTest(), true); }
} // namespace mindspore

View File

@ -0,0 +1,72 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "tools/optimizer/fusion/conv_biasadd_fusion.h"
#include "test/ut/tools/optimizer/fusion/fusion_inout_test/conv_fusion_inout_test.h"
#include "backend/kernel_compiler/cpu/nnacl/op_base.h"
#include "ops/bias_add.h"
namespace mindspore {
class ConvBiasFusionInoutTest : public ConvFusionInoutTest {
public:
ConvBiasFusionInoutTest() = default;
protected:
void InitPass() override { this->pass_ = std::make_shared<opt::ConvBiasaddFusion>(); }
void InitGraph() override {
this->graph_ = std::make_shared<FuncGraph>();
MS_CHECK_TRUE_MSG(graph_ != nullptr, , "Create FuncGraph failed");
auto input = AddParameter(graph_, 0, {1, ih_, iw_, ic_}, kNumberTypeFloat32, "graph_input");
if (input == nullptr) {
this->graph_ = nullptr;
return;
}
auto conv = AddConv(graph_, input, "conv");
if (conv == nullptr) {
this->graph_ = nullptr;
return;
}
auto bias = AddBias(graph_, conv, "conv_bias");
if (bias == nullptr) {
this->graph_ = nullptr;
return;
}
auto ret = AddReturn(graph_, {bias});
if (ret == nullptr) {
this->graph_ = nullptr;
return;
}
}
private:
static CNodePtr AddBias(const FuncGraphPtr &graph, const AnfNodePtr &input, const std::string &name) {
auto prim = std::make_unique<ops::BiasAdd>();
MS_CHECK_TRUE_MSG(prim != nullptr, nullptr, "create BiasAdd primitivec failed");
prim->Init();
auto bias_primitive = NewValueNode(std::shared_ptr<ops::PrimitiveC>(prim.release()));
MS_CHECK_TRUE_RET(bias_primitive != nullptr, nullptr);
auto bias = AddParameter(graph, oc_, {oc_}, kNumberTypeFloat32, name + "_bias");
auto bias_add = graph->NewCNode({bias_primitive, input, bias});
MS_CHECK_TRUE_MSG(bias_add != nullptr, nullptr, "create BiasAdd failed");
bias_add->set_fullname_with_scope(name);
return bias_add;
}
};
TEST_F(ConvBiasFusionInoutTest, test) { ASSERT_EQ(DoTest(), true); }
} // namespace mindspore

View File

@ -0,0 +1,43 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "test/ut/tools/optimizer/fusion/fusion_inout_test/conv_fusion_inout_test.h"
#include <memory>
#include "src/common/log_adapter.h"
#include "ir/func_graph.h"
#include "ops/fusion/conv2d_fusion.h"
#include "backend/kernel_compiler/cpu/nnacl/op_base.h"
namespace mindspore {
ValueNodePtr ConvFusionInoutTest::CreateConvPrimitiveValue() {
auto prim = std::make_unique<ops::Conv2DFusion>();
MS_CHECK_TRUE_MSG(prim != nullptr, nullptr, "create Conv2d primitivec failed");
prim->Init(ic_, oc_, {kh_, kw_});
prim->set_pad_mode(PadMode::SAME);
return NewValueNode(std::shared_ptr<ops::PrimitiveC>(prim.release()));
}
CNodePtr ConvFusionInoutTest::AddConv(const FuncGraphPtr &graph, const AnfNodePtr &input, const std::string &name) {
auto conv_primitive = CreateConvPrimitiveValue();
MS_CHECK_TRUE_RET(conv_primitive != nullptr, nullptr);
auto weight = AddParameter(graph, ic_ * oc_ * kh_ * kw_, {oc_, kh_, kw_, ic_}, kNumberTypeFloat32, name + "_weight");
auto bias = AddParameter(graph, oc_, {oc_}, kNumberTypeFloat32, name + "_bias");
auto conv = graph->NewCNode({conv_primitive, input, weight, bias});
MS_CHECK_TRUE_MSG(conv != nullptr, nullptr, "create Conv2d failed");
conv->set_fullname_with_scope(name);
return conv;
}
} // namespace mindspore

View File

@ -0,0 +1,46 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TEST_UT_TOOLS_OPTIMIZER_FUSION_FUSION_INOUT_TEST_CONV_FUSION_INOUT_TEST_H_
#define MINDSPORE_LITE_TEST_UT_TOOLS_OPTIMIZER_FUSION_FUSION_INOUT_TEST_CONV_FUSION_INOUT_TEST_H_
#include <string>
#include "test/ut/tools/optimizer/fusion/fusion_inout_test/fusion_inout_test.h"
#include "ir/anf.h"
#include "backend/optimizer/common/pass.h"
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/common/pass_manager.h"
namespace mindspore {
class ConvFusionInoutTest : public FusionInoutTest {
public:
ConvFusionInoutTest() = default;
protected:
static ValueNodePtr CreateConvPrimitiveValue();
static CNodePtr AddConv(const FuncGraphPtr &graph, const AnfNodePtr &input, const std::string &name);
protected:
static const int ic_ = 16;
static const int oc_ = 16;
static const int kh_ = 3;
static const int kw_ = 3;
static const int ih_ = 16;
static const int iw_ = 16;
};
} // namespace mindspore
#endif

View File

@ -0,0 +1,162 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "test/ut/tools/optimizer/fusion/fusion_inout_test/fusion_inout_test.h"
#include <memory>
#include "src/common/log_adapter.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/ops/ops_def.h"
#include "ir/func_graph.h"
#include "ops/fusion/conv2d_fusion.h"
#include "backend/kernel_compiler/cpu/nnacl/op_base.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore {
FuncGraphPtr FusionInoutTest::Fuse() {
if (graph_ == nullptr) {
MS_LOG(WARNING) << "Graph not inited";
return nullptr;
}
if (pass_ == nullptr) {
MS_LOG(WARNING) << "Pass not inited";
return graph_;
}
auto optimizer = std::make_shared<opt::GraphOptimizer>();
MS_CHECK_TRUE_MSG(optimizer != nullptr, nullptr, "Create GraphOptimizer failed");
auto fusion_pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false);
MS_CHECK_TRUE_MSG(fusion_pm != nullptr, nullptr, "Create PassManager failed");
fusion_pm->AddPass(pass_);
optimizer->AddPassManager(fusion_pm);
if (optimizer->Optimize(graph_) == nullptr) {
MS_LOG(ERROR) << "run op fusion failed.";
return nullptr;
}
return graph_;
}
ParameterPtr FusionInoutTest::AddParameter(const FuncGraphPtr &graph, size_t data_size,
const std::vector<int64_t> &shape, TypeId data_type,
const std::string &name) {
MS_ASSERT(graph != nullptr);
auto parameter = graph->add_parameter();
if (parameter == nullptr) {
MS_LOG(ERROR) << "CreateParameter failed";
return nullptr;
}
void *data = nullptr;
if (data_size > 0) {
data = malloc(data_size);
if (data == nullptr) {
MS_LOG(ERROR) << "Malloc tensor data failed";
return nullptr;
}
}
auto tensor_info = lite::CreateTensorInfo(data, data_size, shape, data_type);
free(data);
data = nullptr;
if (tensor_info == nullptr) {
MS_LOG(ERROR) << "CreateTensorInfo failed";
return nullptr;
}
auto abstract_tensor = tensor_info->ToAbstract();
if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "CreateTensorAbstract failed";
return nullptr;
}
parameter->set_abstract(abstract_tensor);
if (data_size > 0) {
parameter->set_default_param(tensor_info);
}
parameter->set_name(name);
return parameter;
}
CNodePtr FusionInoutTest::AddReturn(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &return_inputs) {
if (return_inputs.empty()) {
return nullptr;
}
AnfNodePtr return_input = nullptr;
if (return_inputs.size() == 1) {
return_input = return_inputs.front();
} else {
auto make_tuple_prim_ptr = std::make_shared<lite::MakeTuple>();
if (make_tuple_prim_ptr == nullptr) {
MS_LOG(ERROR) << "new MakeTuple failed";
return nullptr;
}
auto return_input_cnode = graph->NewCNode(make_tuple_prim_ptr, return_inputs);
if (return_input_cnode == nullptr) {
MS_LOG(ERROR) << "new make tuple cnode failed";
return nullptr;
}
return_input_cnode->set_fullname_with_scope("return tuple");
return_input = return_input_cnode;
}
auto return_prim = std::make_shared<lite::Return>();
MS_CHECK_TRUE_MSG(return_prim != nullptr, nullptr, "create return primitivec failed");
auto return_cnode = graph->NewCNode(return_prim, {return_input});
MS_CHECK_TRUE_MSG(return_cnode != nullptr, nullptr, "create Return failed");
return_cnode->set_fullname_with_scope("Return");
graph->set_return(return_cnode);
return return_cnode;
}
std::vector<std::string> FusionInoutTest::GetInputNames() {
if (graph_ == nullptr) {
return {};
}
auto inputs = graph_->get_inputs();
std::vector<std::string> ret(inputs.size());
std::transform(inputs.begin(), inputs.end(), ret.begin(),
[](const AnfNodePtr &node) { return node->fullname_with_scope(); });
return ret;
}
size_t FusionInoutTest::GetOutputNumber() {
if (graph_ == nullptr) {
return 0;
}
auto ret = graph_->get_return();
auto ret_input = ret->input(1);
if (utils::isa<CNodePtr>(ret_input)) {
auto ret_input_cnode = utils::cast<CNodePtr>(ret_input);
if (!opt::CheckPrimitiveType(ret_input_cnode, prim::kPrimMakeTuple)) {
return 1;
} else {
return ret_input_cnode->inputs().size() - 1;
}
} else {
return 1;
}
}
bool FusionInoutTest::DoTest() {
InitPass();
InitGraph();
auto old_inputs = GetInputNames();
auto old_outputs_num = GetOutputNumber();
auto ret_graph = Fuse();
if (ret_graph == nullptr) {
MS_LOG(ERROR) << "Fusion failed";
return false;
}
auto new_inputs = GetInputNames();
auto new_outputs_num = GetOutputNumber();
return old_inputs == new_inputs && old_outputs_num == new_outputs_num;
}
} // namespace mindspore

View File

@ -0,0 +1,56 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TEST_UT_TOOLS_OPTIMIZER_FUSION_FUSION_INOUT_TEST_FUSION_INOUT_TEST_H_
#define MINDSPORE_LITE_TEST_UT_TOOLS_OPTIMIZER_FUSION_FUSION_INOUT_TEST_FUSION_INOUT_TEST_H_
#include <string>
#include <vector>
#include "common/common_test.h"
#include "ir/anf.h"
#include "backend/optimizer/common/pass.h"
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/common/pass_manager.h"
namespace mindspore {
class FusionInoutTest : public mindspore::CommonTest {
public:
FusionInoutTest() = default;
bool DoTest();
protected:
FuncGraphPtr Fuse();
std::vector<std::string> GetInputNames();
size_t GetOutputNumber();
virtual void InitPass() = 0;
virtual void InitGraph() = 0;
static ParameterPtr AddParameter(const FuncGraphPtr &graph, size_t data_size, const std::vector<int64_t> &shape,
TypeId data_type, const std::string &name);
static CNodePtr AddReturn(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &return_inputs);
protected:
opt::PassPtr pass_ = nullptr;
FuncGraphPtr graph_ = nullptr;
};
} // namespace mindspore
#endif