forked from mindspore-Ecosystem/mindspore
fuse matmul and elementwise in graphkernel
This commit is contained in:
parent
d0b7c98743
commit
c48c2430f0
2
akg
2
akg
|
@ -1 +1 @@
|
|||
Subproject commit e7a391c51e66975d46bacf6425ae8f27e1675f85
|
||||
Subproject commit e2a30d6b8ece4a69790ac9e37ae862fe8124ad7c
|
|
@ -575,6 +575,15 @@ class GraphSplitAscend(GraphSplitByPattern):
|
|||
fused.append(a)
|
||||
return fused, True
|
||||
|
||||
def _matmul_depth(dom):
|
||||
if dom.dom_op().prim != "MatMul":
|
||||
return None
|
||||
fused = []
|
||||
for a, _ in dom.out_relations.items():
|
||||
if a.pattern == PrimLib.ELEMWISE and a.check_acyclic(dom):
|
||||
fused.append(a)
|
||||
return fused, False
|
||||
|
||||
changed = True
|
||||
while changed:
|
||||
changed = self.fuse(_reshape)
|
||||
|
@ -584,6 +593,7 @@ class GraphSplitAscend(GraphSplitByPattern):
|
|||
changed = self.fuse(_reduce_width) or changed
|
||||
changed = self.fuse(_broadcast_depth) or changed
|
||||
changed = self.fuse(_broadcast_width) or changed
|
||||
changed = self.fuse(_matmul_depth) or changed
|
||||
|
||||
def split(graph, target, flags):
|
||||
"""Split graph"""
|
||||
|
|
|
@ -124,6 +124,7 @@
|
|||
#include "backend/optimizer/ascend/enhancer/split_n_optimizer.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "utils/config_manager.h"
|
||||
#include "utils/context/graph_kernel_flags.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
#include "debug/dump_proto.h"
|
||||
#ifdef ENABLE_DUMP_IR
|
||||
|
@ -429,12 +430,16 @@ void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGrap
|
|||
ub_fusion_pm->AddPass(std::make_shared<ConvSingleInFusionPass>(fusion_id_allocator));
|
||||
ub_fusion_pm->AddPass(std::make_shared<BnupdateEltwiseFusionPass>(fusion_id_allocator));
|
||||
ub_fusion_pm->AddPass(std::make_shared<BnupdateEltwiseEltwiseFusionPass>(fusion_id_allocator));
|
||||
ub_fusion_pm->AddPass(std::make_shared<MatmulEltwiseFusionPass>(fusion_id_allocator));
|
||||
if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
|
||||
ub_fusion_pm->AddPass(std::make_shared<MatmulEltwiseFusionPass>(fusion_id_allocator));
|
||||
}
|
||||
ub_fusion_pm->AddPass(std::make_shared<ConvDoubleInFusionPass>(fusion_id_allocator));
|
||||
ub_fusion_pm->AddPass(std::make_shared<ReduceEltwiseFusionPass>(fusion_id_allocator));
|
||||
ub_fusion_pm->AddPass(std::make_shared<SegmentEltwiseFusionPass>(fusion_id_allocator));
|
||||
ub_fusion_pm->AddPass(std::make_shared<MultiOutputFusionPass>(fusion_id_allocator));
|
||||
ub_fusion_pm->AddPass(std::make_shared<EltwiseFusionPass>(fusion_id_allocator));
|
||||
if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
|
||||
ub_fusion_pm->AddPass(std::make_shared<EltwiseFusionPass>(fusion_id_allocator));
|
||||
}
|
||||
ub_fusion_pm->AddPass(std::make_shared<DepthwiseConvEltwiseFusionPass>(fusion_id_allocator));
|
||||
ub_fusion_pm->AddPass(std::make_shared<MatmulConfusionTranposeFusionPass>(fusion_id_allocator));
|
||||
ub_fusion_pm->AddPass(std::make_shared<BatchMatmulFusedMulAddFusionPass>(fusion_id_allocator));
|
||||
|
|
|
@ -35,6 +35,20 @@ bool HasPath(const AnfNodePtr &leaf, const AnfNodePtr &root, const FuncGraphMana
|
|||
static_cast<void>(DeepLinkedGraphSearch(leaf, IncludeUser));
|
||||
return result;
|
||||
}
|
||||
|
||||
// Update matmul's BuildInfo as last input changed
|
||||
void UpdateBuildInfo(const AnfNodePtr &matmul_node, const AnfNodePtr &cast_node) {
|
||||
std::vector<std::string> input_formats = AnfAlgo::GetAllInputFormats(matmul_node);
|
||||
std::vector<TypeId> input_types = AnfAlgo::GetAllInputDeviceTypes(matmul_node);
|
||||
input_types.pop_back();
|
||||
auto cast_types = AnfAlgo::GetAllInputDeviceTypes(cast_node);
|
||||
input_types.push_back(cast_types.front());
|
||||
std::vector<std::string> output_formats = AnfAlgo::GetAllOutputFormats(matmul_node);
|
||||
std::vector<TypeId> output_types = AnfAlgo::GetAllOutputDeviceTypes(matmul_node);
|
||||
auto graph_sel_info =
|
||||
BuildSelectKernelBuildInfo(input_formats, input_types, output_formats, output_types, matmul_node);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(graph_sel_info, matmul_node.get());
|
||||
}
|
||||
} // namespace
|
||||
|
||||
/* MatMul supports fp32 bias, so remove the redundant cast if cast cannot fuse forword
|
||||
|
@ -93,6 +107,7 @@ bool CastMatmulFusion::Run(const FuncGraphPtr &func_graph) {
|
|||
// Case1 : Cast is only used by matmul
|
||||
if (user_index_set.size() == 1) {
|
||||
mng->Replace(cast_node, (cast_node->cast<CNodePtr>())->input(1));
|
||||
UpdateBuildInfo(cnode, cast_node);
|
||||
changed = true;
|
||||
continue;
|
||||
}
|
||||
|
@ -109,6 +124,7 @@ bool CastMatmulFusion::Run(const FuncGraphPtr &func_graph) {
|
|||
cnode->set_input(4, (cast_node->cast<CNodePtr>())->input(1));
|
||||
mng->RemoveRoots();
|
||||
mng->KeepRoots({func_graph});
|
||||
UpdateBuildInfo(cnode, cast_node);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -306,6 +306,21 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> MixedNodesTransToGraph(
|
|||
return std::make_tuple(fg, inputs, outputs);
|
||||
}
|
||||
|
||||
kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector<std::string> &inputs_format,
|
||||
const std::vector<TypeId> &inputs_type,
|
||||
const std::vector<std::string> &output_formats,
|
||||
const std::vector<TypeId> &output_types, const AnfNodePtr &node) {
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
|
||||
graph_info_builder.SetInputsFormat(inputs_format);
|
||||
graph_info_builder.SetInputsDeviceType(inputs_type);
|
||||
graph_info_builder.SetOutputsFormat(output_formats);
|
||||
graph_info_builder.SetOutputsDeviceType(output_types);
|
||||
graph_info_builder.SetProcessor(AnfAlgo::GetProcessor(node));
|
||||
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
|
||||
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
return graph_info_builder.Build();
|
||||
}
|
||||
|
||||
void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
|
||||
const AnfNodePtrList &outputs) {
|
||||
std::vector<std::string> graph_input_format;
|
||||
|
|
|
@ -58,6 +58,10 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> MixedNodesTransToGraph(
|
|||
AnfNodePtrList *src_outputs = nullptr);
|
||||
void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
|
||||
const AnfNodePtrList &outputs);
|
||||
kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector<std::string> &inputs_format,
|
||||
const std::vector<TypeId> &inputs_type,
|
||||
const std::vector<std::string> &output_formats,
|
||||
const std::vector<TypeId> &output_types, const AnfNodePtr &node);
|
||||
AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &kernel_graph, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
|
||||
const AnfNodePtrList &outputs);
|
||||
void ReplaceNewFuseCNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &new_fuse_cnode,
|
||||
|
|
|
@ -28,21 +28,6 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector<std::string> &inputs_format,
|
||||
const std::vector<TypeId> &inputs_type,
|
||||
const std::vector<std::string> &output_formats,
|
||||
const std::vector<TypeId> &output_types, const CNodePtr &cnode) {
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
|
||||
graph_info_builder.SetInputsFormat(inputs_format);
|
||||
graph_info_builder.SetInputsDeviceType(inputs_type);
|
||||
graph_info_builder.SetOutputsFormat(output_formats);
|
||||
graph_info_builder.SetOutputsDeviceType(output_types);
|
||||
graph_info_builder.SetProcessor(AnfAlgo::GetProcessor(cnode));
|
||||
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
|
||||
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
return graph_info_builder.Build();
|
||||
}
|
||||
|
||||
/**
|
||||
* If an Assign's source node was outputted with this Assign, the src-node should be removed from output list,
|
||||
* external users can use the dest-node under the premise of correct execution order.
|
||||
|
|
|
@ -49,7 +49,7 @@ def test_basic():
|
|||
output = get_output(i0, i1, i2, True)
|
||||
expect_np = expect.asnumpy().copy()
|
||||
output_np = output.asnumpy().copy()
|
||||
assert np.allclose(expect_np, output_np, 1.e-4, 1.e-7)
|
||||
assert np.allclose(expect_np, output_np, 5.e-3, 5.e-3)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
|
|
Loading…
Reference in New Issue