forked from mindspore-Ecosystem/mindspore
!1146 add matmul eltwise buffer fusion
Merge pull request !1146 from Etone.Chan/master
This commit is contained in:
commit
5075f0a27e
|
@ -580,27 +580,43 @@ void BufferFusion::MatchDepthwiseConvRelu(const CNodePtr &cnode, const session::
|
|||
}
|
||||
}
|
||||
|
||||
void BufferFusion::MatchMatmulEltwise(const CNodePtr &cnode, const AnfNodePtr &relu_input,
|
||||
const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
auto manager = kernel_graph.manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
std::vector<int> output_used_num{SizeToInt(manager->node_users()[relu_input].size())};
|
||||
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), relu_input);
|
||||
std::unordered_set<AnfNodePtr> record{cnode, relu_input};
|
||||
candidate_fusion->push_back(record);
|
||||
SetRecordFusionId(record);
|
||||
}
|
||||
|
||||
void BufferFusion::MatchOpNamePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) {
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
|
||||
for (auto &node : node_list) {
|
||||
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator.HasFusionIdAttr(node)) {
|
||||
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator.HasFusionIdAttr(node) ||
|
||||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (AnfAlgo::GetCNodeName(cnode) == kBNTrainingReduceOpName) {
|
||||
MatchConvBnreduce(cnode, kernel_graph, candidate_fusion);
|
||||
} else if (AnfAlgo::GetCNodeName(cnode) == kReluV2OpName ||
|
||||
AnfAlgo::GetCNodeName(cnode) == prim::kPrimRelu->name()) {
|
||||
auto relu_input = cnode->input(1);
|
||||
if (relu_input->isa<CNode>() && AnfAlgo::GetCNodeName(relu_input) == prim::kPrimTensorAdd->name()) {
|
||||
MatchBnupdateAddRelu(cnode, relu_input, kernel_graph, candidate_fusion);
|
||||
} else if (relu_input->isa<CNode>() && AnfAlgo::GetCNodeName(relu_input) == prim::kPrimTupleGetItem->name()) {
|
||||
MatchBnupdateRelu(cnode, relu_input, kernel_graph, candidate_fusion);
|
||||
} else if (relu_input->isa<CNode>() &&
|
||||
AnfAlgo::GetCNodeName(relu_input) == prim::kPrimDepthwiseConv2dNative->name()) {
|
||||
MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, true);
|
||||
} else if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) {
|
||||
auto eltwise_input = cnode->input(1);
|
||||
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimMatMul)) {
|
||||
MatchMatmulEltwise(cnode, eltwise_input, kernel_graph, candidate_fusion);
|
||||
}
|
||||
if (AnfAlgo::GetCNodeName(cnode) == kReluV2OpName || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimRelu)) {
|
||||
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTensorAdd)) {
|
||||
MatchBnupdateAddRelu(cnode, eltwise_input, kernel_graph, candidate_fusion);
|
||||
} else if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTupleGetItem)) {
|
||||
MatchBnupdateRelu(cnode, eltwise_input, kernel_graph, candidate_fusion);
|
||||
}
|
||||
}
|
||||
} else if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimDepthwiseConv2dNative->name()) {
|
||||
MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, false);
|
||||
|
|
|
@ -53,6 +53,8 @@ class BufferFusion : public Pass {
|
|||
const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);
|
||||
void MatchDepthwiseConvRelu(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
|
||||
FusedNodeRecord *candidate_fusion, bool is_order);
|
||||
void MatchMatmulEltwise(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph,
|
||||
FusedNodeRecord *candidate_fusion);
|
||||
void MatchOpNamePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);
|
||||
void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);
|
||||
|
||||
|
|
|
@ -33,6 +33,7 @@ relu_op_info = TBERegOp("ReLU") \
|
|||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
|
|
|
@ -26,15 +26,15 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
|
||||
class TestHWBufferFusionPy : public BackendCommon {
|
||||
class TestHWBufferFusion : public BackendCommon {
|
||||
public:
|
||||
TestHWBufferFusionPy() : get_py_fun_("gtest_input.pre_activate.buffer_fusion_test", true) {}
|
||||
~TestHWBufferFusionPy() override = default;
|
||||
TestHWBufferFusion() : get_py_fun_("gtest_input.pre_activate.buffer_fusion_test", true) {}
|
||||
~TestHWBufferFusion() override = default;
|
||||
|
||||
UT::PyFuncGraphFetcher get_py_fun_;
|
||||
};
|
||||
|
||||
TEST_F(TestHWBufferFusionPy, test_tbe_eltwise_fusion_1) {
|
||||
TEST_F(TestHWBufferFusion, test_tbe_eltwise_fusion_1) {
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tbe_eltwise_fusion_1", "before");
|
||||
std::vector<int> shp{2, 32, 224, 224};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
|
@ -90,7 +90,7 @@ TEST_F(TestHWBufferFusionPy, test_tbe_eltwise_fusion_1) {
|
|||
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
||||
}
|
||||
|
||||
TEST_F(TestHWBufferFusionPy, test_tbe_eltwise_fusion_2) {
|
||||
TEST_F(TestHWBufferFusion, test_tbe_eltwise_fusion_2) {
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tbe_eltwise_fusion_2", "before");
|
||||
std::vector<int> shp{32, 10};
|
||||
std::vector<int> shp_bias{10};
|
||||
|
@ -179,7 +179,7 @@ TEST_F(TestHWBufferFusionPy, test_tbe_eltwise_fusion_2) {
|
|||
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
||||
}
|
||||
|
||||
TEST_F(TestHWBufferFusionPy, test_tbe_reduce_eltwise_fusion) {
|
||||
TEST_F(TestHWBufferFusion, test_tbe_reduce_eltwise_fusion) {
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tbe_reduce_eltwise_fusion", "before");
|
||||
std::vector<int> shp{32, 10};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
|
@ -265,5 +265,71 @@ TEST_F(TestHWBufferFusionPy, test_tbe_reduce_eltwise_fusion) {
|
|||
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_tbe_reduce_eltwise_fusion", "after");
|
||||
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
||||
}
|
||||
|
||||
TEST_F(TestHWBufferFusion, test_tbe_matmul_eltwise_fusion) {
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tbe_matmul_eltwise_fusion", "before");
|
||||
std::vector<int> x_shp{2048, 768};
|
||||
std::vector<int> y_shp{768, 768};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, x_shp);
|
||||
auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, y_shp);
|
||||
AbstractBasePtrList args_spec_list{x_abstract, y_abstract};
|
||||
auto kg = GetKernelGraph(g, args_spec_list);
|
||||
|
||||
auto ret = kg->get_return();
|
||||
EXPECT_NE(ret->input(1), nullptr);
|
||||
auto tuple = ret->input(1);
|
||||
EXPECT_NE(tuple, nullptr);
|
||||
auto cast = tuple->cast<CNodePtr>()->input(1);
|
||||
EXPECT_NE(cast, nullptr);
|
||||
auto relu = cast->cast<CNodePtr>()->input(1);
|
||||
EXPECT_NE(relu, nullptr);
|
||||
auto matmul = relu->cast<CNodePtr>()->input(1);
|
||||
|
||||
KernelBuildInfoBuilder builder;
|
||||
builder.SetInputsFormat({"NC1HWC0"});
|
||||
builder.SetOutputsFormat({"NC1HWC0"});
|
||||
builder.SetInputsDeviceType({kFloat32->type_id()});
|
||||
builder.SetOutputsDeviceType({kFloat32->type_id()});
|
||||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
builder.SetFusionType(kernel::FusionType::ELEMWISE);
|
||||
builder.SetProcessor(kernel::Processor::AICORE);
|
||||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
relu->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), relu.get());
|
||||
|
||||
KernelBuildInfoBuilder builder2;
|
||||
builder2.SetInputsFormat({"NC1HWC0", "NC1HWC0"});
|
||||
builder2.SetOutputsFormat({"NC1HWC0"});
|
||||
builder2.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()});
|
||||
builder2.SetOutputsDeviceType({kFloat32->type_id()});
|
||||
builder2.SetKernelType(KernelType::TBE_KERNEL);
|
||||
builder2.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
builder2.SetProcessor(kernel::Processor::AICORE);
|
||||
builder2.SetKernelType(KernelType::TBE_KERNEL);
|
||||
matmul->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder2.Build(), matmul.get());
|
||||
|
||||
KernelBuildInfoBuilder builder1;
|
||||
builder1.SetInputsFormat({"NC1HWC0"});
|
||||
builder1.SetOutputsFormat({"NC1HWC0"});
|
||||
builder1.SetInputsDeviceType({kFloat32->type_id()});
|
||||
builder1.SetOutputsDeviceType({kFloat16->type_id()});
|
||||
builder1.SetKernelType(KernelType::TBE_KERNEL);
|
||||
builder1.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
builder1.SetProcessor(kernel::Processor::AICORE);
|
||||
builder1.SetKernelType(KernelType::TBE_KERNEL);
|
||||
cast->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), cast.get());
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
auto buffer_fusion_pass = std::make_shared<opt::BufferFusion>();
|
||||
pm->AddPass(buffer_fusion_pass);
|
||||
optimizer->AddPassManager(pm);
|
||||
FuncGraphPtr new_graph = optimizer->Optimize(kg);
|
||||
|
||||
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_tbe_matmul_eltwise_fusion", "after");
|
||||
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -24,10 +24,12 @@ Reduce = P.ReduceOp()
|
|||
Biasadd = P.BiasAdd()
|
||||
Biasaddgrad = G.BiasAddGrad()
|
||||
Cast = P.Cast()
|
||||
MatMul = P.MatMul()
|
||||
|
||||
Fusion_relu_relu = Primitive('FusionOp_ReLU_ReLU')
|
||||
Fusion_biasadd = Primitive('FusionOp_ReLU_ReLU_ReLU_BiasAdd_ReLU_ReLU_ReLU')
|
||||
Fusion_biasaddgrad = Primitive('FusionOp_ReLU_ReLU_ReLU_BiasAddGrad_ReLU_ReLU_ReLU')
|
||||
Fusion_matmul_relu = Primitive('FusionOp_MatMul_ReLU')
|
||||
|
||||
Add = P.TensorAdd()
|
||||
Sub = P.Sub()
|
||||
|
@ -133,3 +135,23 @@ def test_conv_singlein_fusion(tag):
|
|||
return tuple
|
||||
|
||||
return fns[tag]
|
||||
|
||||
|
||||
def test_tbe_matmul_eltwise_fusion(tag):
|
||||
fns = FnDict()
|
||||
|
||||
@fns
|
||||
def before(x, y):
|
||||
matmul = MatMul(x, y)
|
||||
relu = Relu(matmul)
|
||||
res = Cast(relu, mstype.float16)
|
||||
return res
|
||||
|
||||
@fns
|
||||
def after(x, y):
|
||||
fusion = Fusion_matmul_relu(x, y)
|
||||
res = Cast(fusion)
|
||||
tuple = make_tuple(res)
|
||||
return tuple
|
||||
|
||||
return fns[tag]
|
||||
|
|
Loading…
Reference in New Issue