update akg/switch off h-fuse
This commit is contained in:
parent
d817105cba
commit
07a0c24126
2
akg
2
akg
|
@ -1 +1 @@
|
|||
Subproject commit a26bad300e932615d175fc7d34b9e213b5e811aa
|
||||
Subproject commit 1c72d655e24d6def68e3345f3e6cd1a86efde64b
|
|
@ -395,7 +395,7 @@ class GraphSplitByPattern:
|
|||
dom = areas[i]
|
||||
for a in areas[i + 1:]:
|
||||
if dom.check_acyclic(a) and a.check_acyclic(dom) and \
|
||||
selector(dom, a) and self.limit_area_size(dom, [a], 64):
|
||||
selector(dom, a) and self.limit_area_size(dom, [a], 64):
|
||||
dom.fuse(a)
|
||||
self.set_area_map(a.ops, dom)
|
||||
self.areas.remove(a)
|
||||
|
@ -909,28 +909,6 @@ class GraphSplitGpu(GraphSplitByPattern):
|
|||
return [a], True
|
||||
return None
|
||||
|
||||
def _h_broadcast(dom, a):
|
||||
if dom.pattern > PrimLib.BROADCAST:
|
||||
return None
|
||||
return a.pattern <= PrimLib.BROADCAST and dom.ops[0].output.shape == a.ops[0].output.shape
|
||||
|
||||
def _h_reduce(dom, a):
|
||||
if dom.pattern != PrimLib.REDUCE or dom.stitch_info.stitch_ops:
|
||||
return None
|
||||
dom_op = dom.ops[0]
|
||||
if not PrimLib.is_reduce(dom_op) or _is_atomic_add_available(dom):
|
||||
return None
|
||||
op = a.ops[0]
|
||||
return a.pattern == PrimLib.REDUCE and not a.stitch_info.stitch_ops and \
|
||||
PrimLib.is_reduce(op) and dom_op.inputs[0].shape == op.inputs[0].shape and \
|
||||
dom_op.attrs.get("reduce_axis") == op.attrs.get("reduce_axis")
|
||||
|
||||
def _h_opaque(dom, a):
|
||||
if dom.ops[0].prim not in {"StridedSlice"}:
|
||||
return None
|
||||
return a.ops[0].prim == dom.ops[0].prim and dom.ops[0].output.shape == a.ops[0].output.shape and \
|
||||
dom.ops[0].inputs[0].shape == a.ops[0].inputs[0].shape
|
||||
|
||||
def _fuse_loop():
|
||||
changed = True
|
||||
while changed:
|
||||
|
@ -948,9 +926,6 @@ class GraphSplitGpu(GraphSplitByPattern):
|
|||
if self.enable_stitch_fusion:
|
||||
changed = self.fuse(_reduce_stitch) or changed
|
||||
self.fuse(_transpose)
|
||||
self.hfuse(_h_broadcast)
|
||||
self.hfuse(_h_reduce)
|
||||
self.hfuse(_h_opaque)
|
||||
|
||||
def _fuse_once(fuse_func):
|
||||
if fuse_func(_reshape) or fuse_func(_elemwise_depth) or fuse_func(_elemwise_width) or \
|
||||
|
|
Loading…
Reference in New Issue