fuse matmul and elementwise in graphkernel

lingyunli63 2021-04-07 18:14:27 +08:00
parent d0b7c98743
commit c48c2430f0
8 changed files with 54 additions and 19 deletions

akg (submodule)

@@ -1 +1 @@
-Subproject commit e7a391c51e66975d46bacf6425ae8f27e1675f85
+Subproject commit e2a30d6b8ece4a69790ac9e37ae862fe8124ad7c


@@ -575,6 +575,15 @@ class GraphSplitAscend(GraphSplitByPattern):
                     fused.append(a)
             return fused, True
 
+        def _matmul_depth(dom):
+            if dom.dom_op().prim != "MatMul":
+                return None
+            fused = []
+            for a, _ in dom.out_relations.items():
+                if a.pattern == PrimLib.ELEMWISE and a.check_acyclic(dom):
+                    fused.append(a)
+            return fused, False
+
         changed = True
         while changed:
             changed = self.fuse(_reshape)
@@ -584,6 +593,7 @@ class GraphSplitAscend(GraphSplitByPattern):
             changed = self.fuse(_reduce_width) or changed
             changed = self.fuse(_broadcast_depth) or changed
             changed = self.fuse(_broadcast_width) or changed
+            changed = self.fuse(_matmul_depth) or changed
 
 def split(graph, target, flags):
     """Split graph"""


@@ -124,6 +124,7 @@
 #include "backend/optimizer/ascend/enhancer/split_n_optimizer.h"
 #include "utils/ms_context.h"
 #include "utils/config_manager.h"
+#include "utils/context/graph_kernel_flags.h"
 #include "debug/anf_ir_dump.h"
 #include "debug/dump_proto.h"
 #ifdef ENABLE_DUMP_IR
@@ -429,12 +430,16 @@ void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph
   ub_fusion_pm->AddPass(std::make_shared<ConvSingleInFusionPass>(fusion_id_allocator));
   ub_fusion_pm->AddPass(std::make_shared<BnupdateEltwiseFusionPass>(fusion_id_allocator));
   ub_fusion_pm->AddPass(std::make_shared<BnupdateEltwiseEltwiseFusionPass>(fusion_id_allocator));
-  ub_fusion_pm->AddPass(std::make_shared<MatmulEltwiseFusionPass>(fusion_id_allocator));
+  if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
+    ub_fusion_pm->AddPass(std::make_shared<MatmulEltwiseFusionPass>(fusion_id_allocator));
+  }
   ub_fusion_pm->AddPass(std::make_shared<ConvDoubleInFusionPass>(fusion_id_allocator));
   ub_fusion_pm->AddPass(std::make_shared<ReduceEltwiseFusionPass>(fusion_id_allocator));
   ub_fusion_pm->AddPass(std::make_shared<SegmentEltwiseFusionPass>(fusion_id_allocator));
   ub_fusion_pm->AddPass(std::make_shared<MultiOutputFusionPass>(fusion_id_allocator));
-  ub_fusion_pm->AddPass(std::make_shared<EltwiseFusionPass>(fusion_id_allocator));
+  if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
+    ub_fusion_pm->AddPass(std::make_shared<EltwiseFusionPass>(fusion_id_allocator));
+  }
   ub_fusion_pm->AddPass(std::make_shared<DepthwiseConvEltwiseFusionPass>(fusion_id_allocator));
   ub_fusion_pm->AddPass(std::make_shared<MatmulConfusionTranposeFusionPass>(fusion_id_allocator));
   ub_fusion_pm->AddPass(std::make_shared<BatchMatmulFusedMulAddFusionPass>(fusion_id_allocator));
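
Guarding MatmulEltwiseFusionPass and EltwiseFusionPass avoids having TBE's UB fusion claim MatMul + elementwise chains that graph kernel now fuses itself. From the user side this hinges on a single context flag; a hedged usage sketch follows (illustrative network, MindSpore 1.x-era API):

from mindspore import context, nn, ops

# With graph kernel enabled, the two UB-fusion passes guarded above are
# skipped and AKG's graph-kernel fusion handles the pattern instead.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
                    enable_graph_kernel=True)

class MatMulAdd(nn.Cell):
    """MatMul followed by an elementwise Add: the pattern targeted here."""
    def __init__(self):
        super().__init__()
        self.matmul = ops.MatMul()
        self.add = ops.Add()

    def construct(self, x, w, b):
        return self.add(self.matmul(x, w), b)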


@@ -35,6 +35,20 @@ bool HasPath(const AnfNodePtr &leaf, const AnfNodePtr &root, const FuncGraphManager
   static_cast<void>(DeepLinkedGraphSearch(leaf, IncludeUser));
   return result;
 }
+
+// Update matmul's BuildInfo as last input changed
+void UpdateBuildInfo(const AnfNodePtr &matmul_node, const AnfNodePtr &cast_node) {
+  std::vector<std::string> input_formats = AnfAlgo::GetAllInputFormats(matmul_node);
+  std::vector<TypeId> input_types = AnfAlgo::GetAllInputDeviceTypes(matmul_node);
+  input_types.pop_back();
+  auto cast_types = AnfAlgo::GetAllInputDeviceTypes(cast_node);
+  input_types.push_back(cast_types.front());
+  std::vector<std::string> output_formats = AnfAlgo::GetAllOutputFormats(matmul_node);
+  std::vector<TypeId> output_types = AnfAlgo::GetAllOutputDeviceTypes(matmul_node);
+  auto graph_sel_info =
+    BuildSelectKernelBuildInfo(input_formats, input_types, output_formats, output_types, matmul_node);
+  AnfAlgo::SetSelectKernelBuildInfo(graph_sel_info, matmul_node.get());
+}
 }  // namespace
 
 /* MatMul supports fp32 bias, so remove the redundant cast if the cast cannot fuse forward
@@ -93,6 +107,7 @@ bool CastMatmulFusion::Run(const FuncGraphPtr &func_graph) {
     // Case1 : Cast is only used by matmul
     if (user_index_set.size() == 1) {
       mng->Replace(cast_node, (cast_node->cast<CNodePtr>())->input(1));
+      UpdateBuildInfo(cnode, cast_node);
       changed = true;
       continue;
     }
@@ -109,6 +124,7 @@ bool CastMatmulFusion::Run(const FuncGraphPtr &func_graph) {
       cnode->set_input(4, (cast_node->cast<CNodePtr>())->input(1));
       mng->RemoveRoots();
       mng->KeepRoots({func_graph});
+      UpdateBuildInfo(cnode, cast_node);
       changed = true;
     }
   }
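
Without this fix, removing the Cast would leave MatMul's kernel build info describing the old, post-Cast device type for its last input. The essence of UpdateBuildInfo in plain Python (toy helper, not the real AnfAlgo API):

def update_last_input_type(matmul_input_types, cast_input_types):
    """Replace the last recorded input type with the Cast's own input type."""
    types = list(matmul_input_types)
    types.pop()                        # drop the type the removed Cast produced
    types.append(cast_input_types[0])  # keep the type the Cast consumed
    return types

# The bias used to pass through a Cast; after removing it, the build info
# must record the original bias type again (values illustrative).
old = ["float16", "float16", "float16"]
new = update_last_input_type(old, ["float32"])
assert new == ["float16", "float16", "float32"]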


@@ -306,6 +306,21 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> MixedNodesTransToGraph(
   return std::make_tuple(fg, inputs, outputs);
 }
 
+kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector<std::string> &inputs_format,
+                                                      const std::vector<TypeId> &inputs_type,
+                                                      const std::vector<std::string> &output_formats,
+                                                      const std::vector<TypeId> &output_types, const AnfNodePtr &node) {
+  kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
+  graph_info_builder.SetInputsFormat(inputs_format);
+  graph_info_builder.SetInputsDeviceType(inputs_type);
+  graph_info_builder.SetOutputsFormat(output_formats);
+  graph_info_builder.SetOutputsDeviceType(output_types);
+  graph_info_builder.SetProcessor(AnfAlgo::GetProcessor(node));
+  graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
+  graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
+  return graph_info_builder.Build();
+}
+
 void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
                       const AnfNodePtrList &outputs) {
   std::vector<std::string> graph_input_format;
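
Promoting BuildSelectKernelBuildInfo into graph_kernel_helper lets CastMatmulFusion above and the pass below, which previously kept a private copy, build identical AKG kernel-build info. For readers new to the pattern, a toy Python rendering of the builder it wraps (hypothetical names, for illustration only):

class KernelBuildInfoBuilder:
    """Toy builder: collect selection attributes, then emit a frozen record."""
    def __init__(self):
        self._attrs = {}

    def set(self, key, value):
        self._attrs[key] = value
        return self                     # allow chained calls

    def build(self):
        return tuple(sorted(self._attrs.items()))

info = (KernelBuildInfoBuilder()
        .set("kernel_type", "AKG_KERNEL")
        .set("fusion_type", "OPAQUE")
        .build())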


@@ -58,6 +58,10 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> MixedNodesTransToGraph(
                                                                                  AnfNodePtrList *src_outputs = nullptr);
 void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
                       const AnfNodePtrList &outputs);
+kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector<std::string> &inputs_format,
+                                                      const std::vector<TypeId> &inputs_type,
+                                                      const std::vector<std::string> &output_formats,
+                                                      const std::vector<TypeId> &output_types, const AnfNodePtr &node);
 AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &kernel_graph, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
                               const AnfNodePtrList &outputs);
 void ReplaceNewFuseCNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &new_fuse_cnode,


@@ -28,21 +28,6 @@
 namespace mindspore {
 namespace opt {
 namespace {
-kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector<std::string> &inputs_format,
-                                                      const std::vector<TypeId> &inputs_type,
-                                                      const std::vector<std::string> &output_formats,
-                                                      const std::vector<TypeId> &output_types, const CNodePtr &cnode) {
-  kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
-  graph_info_builder.SetInputsFormat(inputs_format);
-  graph_info_builder.SetInputsDeviceType(inputs_type);
-  graph_info_builder.SetOutputsFormat(output_formats);
-  graph_info_builder.SetOutputsDeviceType(output_types);
-  graph_info_builder.SetProcessor(AnfAlgo::GetProcessor(cnode));
-  graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
-  graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
-  return graph_info_builder.Build();
-}
-
 /**
  * If an Assign's source node was outputted with this Assign, the src-node should be removed from output list,
  * external users can use the dest-node under the premise of correct execution order.


@@ -49,7 +49,7 @@ def test_basic():
     output = get_output(i0, i1, i2, True)
     expect_np = expect.asnumpy().copy()
     output_np = output.asnumpy().copy()
-    assert np.allclose(expect_np, output_np, 1.e-4, 1.e-7)
+    assert np.allclose(expect_np, output_np, 5.e-3, 5.e-3)
 
 @pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training
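
On the loosened assertion: in np.allclose(a, b, rtol, atol) the third positional argument is rtol and the fourth is atol, so the tolerance moves from rtol=1e-4/atol=1e-7 to rtol=5e-3/atol=5e-3, presumably to absorb numeric drift from the newly fused MatMul kernel. A quick self-check (values illustrative):

import numpy as np

# np.allclose(a, b, rtol, atol): third positional arg is rtol, fourth is atol.
a = np.array([1.0, 100.0], np.float32)
b = np.array([1.004, 100.4], np.float32)
assert np.allclose(a, b, 5.e-3, 5.e-3)        # passes under the new tolerance
assert not np.allclose(a, b, 1.e-4, 1.e-7)    # would fail under the old one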