update akg/switch off h-fuse

This commit is contained in:
Yang Jiao 2021-09-24 14:45:03 +08:00
parent d817105cba
commit 07a0c24126
2 changed files with 2 additions and 27 deletions

2
akg

@ -1 +1 @@
Subproject commit a26bad300e932615d175fc7d34b9e213b5e811aa
Subproject commit 1c72d655e24d6def68e3345f3e6cd1a86efde64b

View File

@ -395,7 +395,7 @@ class GraphSplitByPattern:
dom = areas[i]
for a in areas[i + 1:]:
if dom.check_acyclic(a) and a.check_acyclic(dom) and \
selector(dom, a) and self.limit_area_size(dom, [a], 64):
selector(dom, a) and self.limit_area_size(dom, [a], 64):
dom.fuse(a)
self.set_area_map(a.ops, dom)
self.areas.remove(a)
@ -909,28 +909,6 @@ class GraphSplitGpu(GraphSplitByPattern):
return [a], True
return None
def _h_broadcast(dom, a):
if dom.pattern > PrimLib.BROADCAST:
return None
return a.pattern <= PrimLib.BROADCAST and dom.ops[0].output.shape == a.ops[0].output.shape
def _h_reduce(dom, a):
if dom.pattern != PrimLib.REDUCE or dom.stitch_info.stitch_ops:
return None
dom_op = dom.ops[0]
if not PrimLib.is_reduce(dom_op) or _is_atomic_add_available(dom):
return None
op = a.ops[0]
return a.pattern == PrimLib.REDUCE and not a.stitch_info.stitch_ops and \
PrimLib.is_reduce(op) and dom_op.inputs[0].shape == op.inputs[0].shape and \
dom_op.attrs.get("reduce_axis") == op.attrs.get("reduce_axis")
def _h_opaque(dom, a):
if dom.ops[0].prim not in {"StridedSlice"}:
return None
return a.ops[0].prim == dom.ops[0].prim and dom.ops[0].output.shape == a.ops[0].output.shape and \
dom.ops[0].inputs[0].shape == a.ops[0].inputs[0].shape
def _fuse_loop():
changed = True
while changed:
@ -948,9 +926,6 @@ class GraphSplitGpu(GraphSplitByPattern):
if self.enable_stitch_fusion:
changed = self.fuse(_reduce_stitch) or changed
self.fuse(_transpose)
self.hfuse(_h_broadcast)
self.hfuse(_h_reduce)
self.hfuse(_h_opaque)
def _fuse_once(fuse_func):
if fuse_func(_reshape) or fuse_func(_elemwise_depth) or fuse_func(_elemwise_width) or \