forked from mindspore-Ecosystem/mindspore
!23431 graph cycle bugfix of horizontal fusion
Merge pull request !23431 from Gaoxiong/master
This commit is contained in:
commit
cc2c57bad4
|
@ -390,30 +390,17 @@ class GraphSplitByPattern:
|
|||
|
||||
def hfuse(self, selector):
|
||||
"""Fuse horizontal areas with same input tensor"""
|
||||
|
||||
def _do_fuse(areas):
|
||||
for i in range(len(areas) - 1):
|
||||
sibling = []
|
||||
dom = areas[i]
|
||||
for a in areas[i + 1:]:
|
||||
if areas[i].unique_id < a.unique_id:
|
||||
id1, id2 = areas[i].unique_id, a.unique_id
|
||||
else:
|
||||
id1, id2 = a.unique_id, areas[i].unique_id
|
||||
if not self.reach_tab.reachable(id1, id2):
|
||||
sibling.append(a)
|
||||
if sibling:
|
||||
result = selector(areas[i], sibling)
|
||||
if result:
|
||||
fuse_areas = self.limit_area_size(areas[i], result, 64)
|
||||
if not fuse_areas:
|
||||
continue
|
||||
for area in fuse_areas:
|
||||
areas[i].fuse(area)
|
||||
self.set_area_map(area.ops, areas[i])
|
||||
self.areas.remove(area)
|
||||
if dom.check_acyclic(a) and a.check_acyclic(dom) and \
|
||||
selector(dom, a) and self.limit_area_size(dom, [a], 64):
|
||||
dom.fuse(a)
|
||||
self.set_area_map(a.ops, dom)
|
||||
self.areas.remove(a)
|
||||
return True
|
||||
return False
|
||||
|
||||
changed = False
|
||||
while True:
|
||||
for dom in self.areas:
|
||||
|
@ -922,39 +909,27 @@ class GraphSplitGpu(GraphSplitByPattern):
|
|||
return [a], True
|
||||
return None
|
||||
|
||||
def _h_broadcast(dom, sibling):
|
||||
def _h_broadcast(dom, a):
|
||||
if dom.pattern > PrimLib.BROADCAST:
|
||||
return None
|
||||
fused = []
|
||||
for a in sibling:
|
||||
if a.pattern <= PrimLib.BROADCAST and dom.ops[0].output.shape == a.ops[0].output.shape:
|
||||
fused.append(a)
|
||||
return fused
|
||||
return a.pattern <= PrimLib.BROADCAST and dom.ops[0].output.shape == a.ops[0].output.shape
|
||||
|
||||
def _h_reduce(dom, sibling):
|
||||
def _h_reduce(dom, a):
|
||||
if dom.pattern != PrimLib.REDUCE or dom.stitch_info.stitch_ops:
|
||||
return None
|
||||
dom_op = dom.ops[0]
|
||||
if not PrimLib.is_reduce(dom_op) or _is_atomic_add_available(dom):
|
||||
return None
|
||||
fused = []
|
||||
for a in sibling:
|
||||
op = a.ops[0]
|
||||
if a.pattern == PrimLib.REDUCE and not a.stitch_info.stitch_ops and \
|
||||
PrimLib.is_reduce(op) and dom_op.inputs[0].shape == op.inputs[0].shape and \
|
||||
dom_op.attrs.get("reduce_axis") == op.attrs.get("reduce_axis"):
|
||||
fused.append(a)
|
||||
return fused
|
||||
op = a.ops[0]
|
||||
return a.pattern == PrimLib.REDUCE and not a.stitch_info.stitch_ops and \
|
||||
PrimLib.is_reduce(op) and dom_op.inputs[0].shape == op.inputs[0].shape and \
|
||||
dom_op.attrs.get("reduce_axis") == op.attrs.get("reduce_axis")
|
||||
|
||||
def _h_opaque(dom, sibling):
|
||||
def _h_opaque(dom, a):
|
||||
if dom.ops[0].prim not in {"StridedSlice"}:
|
||||
return None
|
||||
fused = []
|
||||
for a in sibling:
|
||||
if a.ops[0].prim == dom.ops[0].prim and dom.ops[0].output.shape == a.ops[0].output.shape and \
|
||||
dom.ops[0].inputs[0].shape == a.ops[0].inputs[0].shape:
|
||||
fused.append(a)
|
||||
return fused
|
||||
return a.ops[0].prim == dom.ops[0].prim and dom.ops[0].output.shape == a.ops[0].output.shape and \
|
||||
dom.ops[0].inputs[0].shape == a.ops[0].inputs[0].shape
|
||||
|
||||
def _fuse_loop():
|
||||
changed = True
|
||||
|
|
Loading…
Reference in New Issue