!23431 graph cycle bugfix of horizontal fusion

Merge pull request !23431 from Gaoxiong/master
This commit is contained in:
i-robot 2021-09-14 09:45:42 +00:00 committed by Gitee
commit cc2c57bad4
1 changed file with 16 additions and 41 deletions

View File

@@ -390,30 +390,17 @@ class GraphSplitByPattern:
def hfuse(self, selector):
"""Fuse horizontal areas with same input tensor"""
def _do_fuse(areas):
for i in range(len(areas) - 1):
sibling = []
dom = areas[i]
for a in areas[i + 1:]:
if areas[i].unique_id < a.unique_id:
id1, id2 = areas[i].unique_id, a.unique_id
else:
id1, id2 = a.unique_id, areas[i].unique_id
if not self.reach_tab.reachable(id1, id2):
sibling.append(a)
if sibling:
result = selector(areas[i], sibling)
if result:
fuse_areas = self.limit_area_size(areas[i], result, 64)
if not fuse_areas:
continue
for area in fuse_areas:
areas[i].fuse(area)
self.set_area_map(area.ops, areas[i])
self.areas.remove(area)
if dom.check_acyclic(a) and a.check_acyclic(dom) and \
selector(dom, a) and self.limit_area_size(dom, [a], 64):
dom.fuse(a)
self.set_area_map(a.ops, dom)
self.areas.remove(a)
return True
return False
changed = False
while True:
for dom in self.areas:
@@ -922,39 +909,27 @@ class GraphSplitGpu(GraphSplitByPattern):
return [a], True
return None
def _h_broadcast(dom, sibling):
def _h_broadcast(dom, a):
if dom.pattern > PrimLib.BROADCAST:
return None
fused = []
for a in sibling:
if a.pattern <= PrimLib.BROADCAST and dom.ops[0].output.shape == a.ops[0].output.shape:
fused.append(a)
return fused
return a.pattern <= PrimLib.BROADCAST and dom.ops[0].output.shape == a.ops[0].output.shape
def _h_reduce(dom, sibling):
def _h_reduce(dom, a):
if dom.pattern != PrimLib.REDUCE or dom.stitch_info.stitch_ops:
return None
dom_op = dom.ops[0]
if not PrimLib.is_reduce(dom_op) or _is_atomic_add_available(dom):
return None
fused = []
for a in sibling:
op = a.ops[0]
if a.pattern == PrimLib.REDUCE and not a.stitch_info.stitch_ops and \
PrimLib.is_reduce(op) and dom_op.inputs[0].shape == op.inputs[0].shape and \
dom_op.attrs.get("reduce_axis") == op.attrs.get("reduce_axis"):
fused.append(a)
return fused
op = a.ops[0]
return a.pattern == PrimLib.REDUCE and not a.stitch_info.stitch_ops and \
PrimLib.is_reduce(op) and dom_op.inputs[0].shape == op.inputs[0].shape and \
dom_op.attrs.get("reduce_axis") == op.attrs.get("reduce_axis")
def _h_opaque(dom, sibling):
def _h_opaque(dom, a):
if dom.ops[0].prim not in {"StridedSlice"}:
return None
fused = []
for a in sibling:
if a.ops[0].prim == dom.ops[0].prim and dom.ops[0].output.shape == a.ops[0].output.shape and \
dom.ops[0].inputs[0].shape == a.ops[0].inputs[0].shape:
fused.append(a)
return fused
return a.ops[0].prim == dom.ops[0].prim and dom.ops[0].output.shape == a.ops[0].output.shape and \
dom.ops[0].inputs[0].shape == a.ops[0].inputs[0].shape
def _fuse_loop():
changed = True