diff --git a/akg b/akg index e7a391c51e6..e2a30d6b8ec 160000 --- a/akg +++ b/akg @@ -1 +1 @@ -Subproject commit e7a391c51e66975d46bacf6425ae8f27e1675f85 +Subproject commit e2a30d6b8ece4a69790ac9e37ae862fe8124ad7c diff --git a/mindspore/_extends/graph_kernel/model/graph_split.py b/mindspore/_extends/graph_kernel/model/graph_split.py index 05654b376f5..3f1b5bc18bd 100644 --- a/mindspore/_extends/graph_kernel/model/graph_split.py +++ b/mindspore/_extends/graph_kernel/model/graph_split.py @@ -575,6 +575,15 @@ class GraphSplitAscend(GraphSplitByPattern): fused.append(a) return fused, True + def _matmul_depth(dom): + if dom.dom_op().prim != "MatMul": + return None + fused = [] + for a, _ in dom.out_relations.items(): + if a.pattern == PrimLib.ELEMWISE and a.check_acyclic(dom): + fused.append(a) + return fused, False + changed = True while changed: changed = self.fuse(_reshape) @@ -584,6 +593,7 @@ class GraphSplitAscend(GraphSplitByPattern): changed = self.fuse(_reduce_width) or changed changed = self.fuse(_broadcast_depth) or changed changed = self.fuse(_broadcast_width) or changed + changed = self.fuse(_matmul_depth) or changed def split(graph, target, flags): """Split graph""" diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc index bd5cedc9a56..0b67d12f306 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc @@ -124,6 +124,7 @@ #include "backend/optimizer/ascend/enhancer/split_n_optimizer.h" #include "utils/ms_context.h" #include "utils/config_manager.h" +#include "utils/context/graph_kernel_flags.h" #include "debug/anf_ir_dump.h" #include "debug/dump_proto.h" #ifdef ENABLE_DUMP_IR @@ -429,12 +430,16 @@ void AscendBackendUBFusionOptimization(const std::shared_ptrAddPass(std::make_shared(fusion_id_allocator)); ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) { + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + } ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) { + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + } ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/cast_matmul_fusion.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/cast_matmul_fusion.cc index 7c1d9913dc4..5bb94b6d0c1 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/cast_matmul_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/cast_matmul_fusion.cc @@ -35,6 +35,20 @@ bool HasPath(const AnfNodePtr &leaf, const AnfNodePtr &root, const FuncGraphMana static_cast(DeepLinkedGraphSearch(leaf, IncludeUser)); return result; } + +// Update matmul's BuildInfo as last input changed +void UpdateBuildInfo(const AnfNodePtr &matmul_node, const AnfNodePtr &cast_node) { + std::vector input_formats = AnfAlgo::GetAllInputFormats(matmul_node); + std::vector input_types = AnfAlgo::GetAllInputDeviceTypes(matmul_node); + input_types.pop_back(); + auto cast_types = AnfAlgo::GetAllInputDeviceTypes(cast_node); + input_types.push_back(cast_types.front()); + std::vector output_formats = AnfAlgo::GetAllOutputFormats(matmul_node); + std::vector output_types = AnfAlgo::GetAllOutputDeviceTypes(matmul_node); + auto graph_sel_info = + BuildSelectKernelBuildInfo(input_formats, input_types, output_formats, output_types, matmul_node); + AnfAlgo::SetSelectKernelBuildInfo(graph_sel_info, matmul_node.get()); +} } // namespace /* MatMul supports fp32 bias, so remove the redundant cast if cast cannot fuse forword @@ -93,6 +107,7 @@ bool CastMatmulFusion::Run(const FuncGraphPtr &func_graph) { // Case1 : Cast is only used by matmul if (user_index_set.size() == 1) { mng->Replace(cast_node, (cast_node->cast())->input(1)); + UpdateBuildInfo(cnode, cast_node); changed = true; continue; } @@ -109,6 +124,7 @@ bool CastMatmulFusion::Run(const FuncGraphPtr &func_graph) { cnode->set_input(4, (cast_node->cast())->input(1)); mng->RemoveRoots(); mng->KeepRoots({func_graph}); + UpdateBuildInfo(cnode, cast_node); changed = true; } } diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc index 117cf74bb1c..e8b83187ab2 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc @@ -306,6 +306,21 @@ std::tuple MixedNodesTransToGraph( return std::make_tuple(fg, inputs, outputs); } +kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector &inputs_format, + const std::vector &inputs_type, + const std::vector &output_formats, + const std::vector &output_types, const AnfNodePtr &node) { + kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder; + graph_info_builder.SetInputsFormat(inputs_format); + graph_info_builder.SetInputsDeviceType(inputs_type); + graph_info_builder.SetOutputsFormat(output_formats); + graph_info_builder.SetOutputsDeviceType(output_types); + graph_info_builder.SetProcessor(AnfAlgo::GetProcessor(node)); + graph_info_builder.SetKernelType(KernelType::AKG_KERNEL); + graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE); + return graph_info_builder.Build(); +} + void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs, const AnfNodePtrList &outputs) { std::vector graph_input_format; diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h index 79b853bbb7a..455d2dd2c2c 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h @@ -58,6 +58,10 @@ std::tuple MixedNodesTransToGraph( AnfNodePtrList *src_outputs = nullptr); void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs, const AnfNodePtrList &outputs); +kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector &inputs_format, + const std::vector &inputs_type, + const std::vector &output_formats, + const std::vector &output_types, const AnfNodePtr &node); AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &kernel_graph, const FuncGraphPtr &fg, const AnfNodePtrList &inputs, const AnfNodePtrList &outputs); void ReplaceNewFuseCNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &new_fuse_cnode, diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/optimize_assign.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/optimize_assign.cc index 2048cdda9b1..dd1904179e6 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/optimize_assign.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/optimize_assign.cc @@ -28,21 +28,6 @@ namespace mindspore { namespace opt { namespace { -kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector &inputs_format, - const std::vector &inputs_type, - const std::vector &output_formats, - const std::vector &output_types, const CNodePtr &cnode) { - kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder; - graph_info_builder.SetInputsFormat(inputs_format); - graph_info_builder.SetInputsDeviceType(inputs_type); - graph_info_builder.SetOutputsFormat(output_formats); - graph_info_builder.SetOutputsDeviceType(output_types); - graph_info_builder.SetProcessor(AnfAlgo::GetProcessor(cnode)); - graph_info_builder.SetKernelType(KernelType::AKG_KERNEL); - graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE); - return graph_info_builder.Build(); -} - /** * If an Assign's source node was outputted with this Assign, the src-node should be removed from output list, * external users can use the dest-node under the premise of correct execution order. diff --git a/tests/st/ops/graph_kernel/test_cast_matmul_fusion.py b/tests/st/ops/graph_kernel/test_cast_matmul_fusion.py index 49731fbfff6..6612f2b094c 100644 --- a/tests/st/ops/graph_kernel/test_cast_matmul_fusion.py +++ b/tests/st/ops/graph_kernel/test_cast_matmul_fusion.py @@ -49,7 +49,7 @@ def test_basic(): output = get_output(i0, i1, i2, True) expect_np = expect.asnumpy().copy() output_np = output.asnumpy().copy() - assert np.allclose(expect_np, output_np, 1.e-4, 1.e-7) + assert np.allclose(expect_np, output_np, 5.e-3, 5.e-3) @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training