add matmul eltwise buffer fusion pass

This commit is contained in:
etone-chan 2020-05-14 10:29:10 +08:00
parent a2a3b1c6c5
commit c4a5bfb787
5 changed files with 124 additions and 17 deletions

View File

@ -580,27 +580,43 @@ void BufferFusion::MatchDepthwiseConvRelu(const CNodePtr &cnode, const session::
}
}
void BufferFusion::MatchMatmulEltwise(const CNodePtr &cnode, const AnfNodePtr &relu_input,
const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(candidate_fusion);
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
std::vector<int> output_used_num{SizeToInt(manager->node_users()[relu_input].size())};
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), relu_input);
std::unordered_set<AnfNodePtr> record{cnode, relu_input};
candidate_fusion->push_back(record);
SetRecordFusionId(record);
}
void BufferFusion::MatchOpNamePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) {
MS_EXCEPTION_IF_NULL(candidate_fusion);
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
for (auto &node : node_list) {
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator.HasFusionIdAttr(node)) {
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator.HasFusionIdAttr(node) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetCNodeName(cnode) == kBNTrainingReduceOpName) {
MatchConvBnreduce(cnode, kernel_graph, candidate_fusion);
} else if (AnfAlgo::GetCNodeName(cnode) == kReluV2OpName ||
AnfAlgo::GetCNodeName(cnode) == prim::kPrimRelu->name()) {
auto relu_input = cnode->input(1);
if (relu_input->isa<CNode>() && AnfAlgo::GetCNodeName(relu_input) == prim::kPrimTensorAdd->name()) {
MatchBnupdateAddRelu(cnode, relu_input, kernel_graph, candidate_fusion);
} else if (relu_input->isa<CNode>() && AnfAlgo::GetCNodeName(relu_input) == prim::kPrimTupleGetItem->name()) {
MatchBnupdateRelu(cnode, relu_input, kernel_graph, candidate_fusion);
} else if (relu_input->isa<CNode>() &&
AnfAlgo::GetCNodeName(relu_input) == prim::kPrimDepthwiseConv2dNative->name()) {
MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, true);
} else if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) {
auto eltwise_input = cnode->input(1);
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimMatMul)) {
MatchMatmulEltwise(cnode, eltwise_input, kernel_graph, candidate_fusion);
}
if (AnfAlgo::GetCNodeName(cnode) == kReluV2OpName || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimRelu)) {
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTensorAdd)) {
MatchBnupdateAddRelu(cnode, eltwise_input, kernel_graph, candidate_fusion);
} else if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTupleGetItem)) {
MatchBnupdateRelu(cnode, eltwise_input, kernel_graph, candidate_fusion);
}
}
} else if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimDepthwiseConv2dNative->name()) {
MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, false);

View File

@ -53,6 +53,8 @@ class BufferFusion : public Pass {
const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);
void MatchDepthwiseConvRelu(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion, bool is_order);
void MatchMatmulEltwise(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion);
void MatchOpNamePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);
void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);

View File

@ -33,6 +33,7 @@ relu_op_info = TBERegOp("ReLU") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ) \
.get_op_info()

View File

@ -26,15 +26,15 @@
namespace mindspore {
namespace opt {
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
class TestHWBufferFusionPy : public BackendCommon {
class TestHWBufferFusion : public BackendCommon {
public:
TestHWBufferFusionPy() : get_py_fun_("gtest_input.pre_activate.buffer_fusion_test", true) {}
~TestHWBufferFusionPy() override = default;
TestHWBufferFusion() : get_py_fun_("gtest_input.pre_activate.buffer_fusion_test", true) {}
~TestHWBufferFusion() override = default;
UT::PyFuncGraphFetcher get_py_fun_;
};
TEST_F(TestHWBufferFusionPy, test_tbe_eltwise_fusion_1) {
TEST_F(TestHWBufferFusion, test_tbe_eltwise_fusion_1) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tbe_eltwise_fusion_1", "before");
std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
@ -90,7 +90,7 @@ TEST_F(TestHWBufferFusionPy, test_tbe_eltwise_fusion_1) {
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
TEST_F(TestHWBufferFusionPy, test_tbe_eltwise_fusion_2) {
TEST_F(TestHWBufferFusion, test_tbe_eltwise_fusion_2) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tbe_eltwise_fusion_2", "before");
std::vector<int> shp{32, 10};
std::vector<int> shp_bias{10};
@ -179,7 +179,7 @@ TEST_F(TestHWBufferFusionPy, test_tbe_eltwise_fusion_2) {
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
TEST_F(TestHWBufferFusionPy, test_tbe_reduce_eltwise_fusion) {
TEST_F(TestHWBufferFusion, test_tbe_reduce_eltwise_fusion) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tbe_reduce_eltwise_fusion", "before");
std::vector<int> shp{32, 10};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
@ -265,5 +265,71 @@ TEST_F(TestHWBufferFusionPy, test_tbe_reduce_eltwise_fusion) {
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_tbe_reduce_eltwise_fusion", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
TEST_F(TestHWBufferFusion, test_tbe_matmul_eltwise_fusion) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tbe_matmul_eltwise_fusion", "before");
std::vector<int> x_shp{2048, 768};
std::vector<int> y_shp{768, 768};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, x_shp);
auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, y_shp);
AbstractBasePtrList args_spec_list{x_abstract, y_abstract};
auto kg = GetKernelGraph(g, args_spec_list);
auto ret = kg->get_return();
EXPECT_NE(ret->input(1), nullptr);
auto tuple = ret->input(1);
EXPECT_NE(tuple, nullptr);
auto cast = tuple->cast<CNodePtr>()->input(1);
EXPECT_NE(cast, nullptr);
auto relu = cast->cast<CNodePtr>()->input(1);
EXPECT_NE(relu, nullptr);
auto matmul = relu->cast<CNodePtr>()->input(1);
KernelBuildInfoBuilder builder;
builder.SetInputsFormat({"NC1HWC0"});
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetInputsDeviceType({kFloat32->type_id()});
builder.SetOutputsDeviceType({kFloat32->type_id()});
builder.SetKernelType(KernelType::TBE_KERNEL);
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetKernelType(KernelType::TBE_KERNEL);
relu->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), relu.get());
KernelBuildInfoBuilder builder2;
builder2.SetInputsFormat({"NC1HWC0", "NC1HWC0"});
builder2.SetOutputsFormat({"NC1HWC0"});
builder2.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()});
builder2.SetOutputsDeviceType({kFloat32->type_id()});
builder2.SetKernelType(KernelType::TBE_KERNEL);
builder2.SetFusionType(kernel::FusionType::OPAQUE);
builder2.SetProcessor(kernel::Processor::AICORE);
builder2.SetKernelType(KernelType::TBE_KERNEL);
matmul->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder2.Build(), matmul.get());
KernelBuildInfoBuilder builder1;
builder1.SetInputsFormat({"NC1HWC0"});
builder1.SetOutputsFormat({"NC1HWC0"});
builder1.SetInputsDeviceType({kFloat32->type_id()});
builder1.SetOutputsDeviceType({kFloat16->type_id()});
builder1.SetKernelType(KernelType::TBE_KERNEL);
builder1.SetFusionType(kernel::FusionType::OPAQUE);
builder1.SetProcessor(kernel::Processor::AICORE);
builder1.SetKernelType(KernelType::TBE_KERNEL);
cast->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), cast.get());
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
auto buffer_fusion_pass = std::make_shared<opt::BufferFusion>();
pm->AddPass(buffer_fusion_pass);
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(kg);
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_tbe_matmul_eltwise_fusion", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
} // namespace opt
} // namespace mindspore

View File

@ -24,10 +24,12 @@ Reduce = P.ReduceOp()
Biasadd = P.BiasAdd()
Biasaddgrad = G.BiasAddGrad()
Cast = P.Cast()
MatMul = P.MatMul()
Fusion_relu_relu = Primitive('FusionOp_ReLU_ReLU')
Fusion_biasadd = Primitive('FusionOp_ReLU_ReLU_ReLU_BiasAdd_ReLU_ReLU_ReLU')
Fusion_biasaddgrad = Primitive('FusionOp_ReLU_ReLU_ReLU_BiasAddGrad_ReLU_ReLU_ReLU')
Fusion_matmul_relu = Primitive('FusionOp_MatMul_ReLU')
Add = P.TensorAdd()
Sub = P.Sub()
@ -133,3 +135,23 @@ def test_conv_singlein_fusion(tag):
return tuple
return fns[tag]
def test_tbe_matmul_eltwise_fusion(tag):
fns = FnDict()
@fns
def before(x, y):
matmul = MatMul(x, y)
relu = Relu(matmul)
res = Cast(relu, mstype.float16)
return res
@fns
def after(x, y):
fusion = Fusion_matmul_relu(x, y)
res = Cast(fusion)
tuple = make_tuple(res)
return tuple
return fns[tag]