From 07a0c2412604bdbd60d19d4aa0949667a7e0288b Mon Sep 17 00:00:00 2001 From: Yang Jiao Date: Fri, 24 Sep 2021 14:45:03 +0800 Subject: [PATCH] update akg/switch off h-fuse --- akg | 2 +- .../graph_kernel/model/graph_split.py | 27 +------------------ 2 files changed, 2 insertions(+), 27 deletions(-) diff --git a/akg b/akg index a26bad300e9..1c72d655e24 160000 --- a/akg +++ b/akg @@ -1 +1 @@ -Subproject commit a26bad300e932615d175fc7d34b9e213b5e811aa +Subproject commit 1c72d655e24d6def68e3345f3e6cd1a86efde64b diff --git a/mindspore/_extends/graph_kernel/model/graph_split.py b/mindspore/_extends/graph_kernel/model/graph_split.py index 15cdda08159..745b4382963 100644 --- a/mindspore/_extends/graph_kernel/model/graph_split.py +++ b/mindspore/_extends/graph_kernel/model/graph_split.py @@ -395,7 +395,7 @@ class GraphSplitByPattern: dom = areas[i] for a in areas[i + 1:]: if dom.check_acyclic(a) and a.check_acyclic(dom) and \ - selector(dom, a) and self.limit_area_size(dom, [a], 64): + selector(dom, a) and self.limit_area_size(dom, [a], 64): dom.fuse(a) self.set_area_map(a.ops, dom) self.areas.remove(a) @@ -909,28 +909,6 @@ class GraphSplitGpu(GraphSplitByPattern): return [a], True return None - def _h_broadcast(dom, a): - if dom.pattern > PrimLib.BROADCAST: - return None - return a.pattern <= PrimLib.BROADCAST and dom.ops[0].output.shape == a.ops[0].output.shape - - def _h_reduce(dom, a): - if dom.pattern != PrimLib.REDUCE or dom.stitch_info.stitch_ops: - return None - dom_op = dom.ops[0] - if not PrimLib.is_reduce(dom_op) or _is_atomic_add_available(dom): - return None - op = a.ops[0] - return a.pattern == PrimLib.REDUCE and not a.stitch_info.stitch_ops and \ - PrimLib.is_reduce(op) and dom_op.inputs[0].shape == op.inputs[0].shape and \ - dom_op.attrs.get("reduce_axis") == op.attrs.get("reduce_axis") - - def _h_opaque(dom, a): - if dom.ops[0].prim not in {"StridedSlice"}: - return None - return a.ops[0].prim == dom.ops[0].prim and dom.ops[0].output.shape == a.ops[0].output.shape and \ - dom.ops[0].inputs[0].shape == a.ops[0].inputs[0].shape - def _fuse_loop(): changed = True while changed: @@ -948,9 +926,6 @@ class GraphSplitGpu(GraphSplitByPattern): if self.enable_stitch_fusion: changed = self.fuse(_reduce_stitch) or changed self.fuse(_transpose) - self.hfuse(_h_broadcast) - self.hfuse(_h_reduce) - self.hfuse(_h_opaque) def _fuse_once(fuse_func): if fuse_func(_reshape) or fuse_func(_elemwise_depth) or fuse_func(_elemwise_width) or \