diff --git a/mindspore/_extends/graph_kernel/model/graph_split.py b/mindspore/_extends/graph_kernel/model/graph_split.py index 4de3c049742..15cdda08159 100644 --- a/mindspore/_extends/graph_kernel/model/graph_split.py +++ b/mindspore/_extends/graph_kernel/model/graph_split.py @@ -390,30 +390,17 @@ class GraphSplitByPattern: def hfuse(self, selector): """Fuse horizontal areas with same input tensor""" - def _do_fuse(areas): for i in range(len(areas) - 1): - sibling = [] + dom = areas[i] for a in areas[i + 1:]: - if areas[i].unique_id < a.unique_id: - id1, id2 = areas[i].unique_id, a.unique_id - else: - id1, id2 = a.unique_id, areas[i].unique_id - if not self.reach_tab.reachable(id1, id2): - sibling.append(a) - if sibling: - result = selector(areas[i], sibling) - if result: - fuse_areas = self.limit_area_size(areas[i], result, 64) - if not fuse_areas: - continue - for area in fuse_areas: - areas[i].fuse(area) - self.set_area_map(area.ops, areas[i]) - self.areas.remove(area) + if dom.check_acyclic(a) and a.check_acyclic(dom) and \ + selector(dom, a) and self.limit_area_size(dom, [a], 64): + dom.fuse(a) + self.set_area_map(a.ops, dom) + self.areas.remove(a) return True return False - changed = False while True: for dom in self.areas: @@ -922,39 +909,27 @@ class GraphSplitGpu(GraphSplitByPattern): return [a], True return None - def _h_broadcast(dom, sibling): + def _h_broadcast(dom, a): if dom.pattern > PrimLib.BROADCAST: return None - fused = [] - for a in sibling: - if a.pattern <= PrimLib.BROADCAST and dom.ops[0].output.shape == a.ops[0].output.shape: - fused.append(a) - return fused + return a.pattern <= PrimLib.BROADCAST and dom.ops[0].output.shape == a.ops[0].output.shape - def _h_reduce(dom, sibling): + def _h_reduce(dom, a): if dom.pattern != PrimLib.REDUCE or dom.stitch_info.stitch_ops: return None dom_op = dom.ops[0] if not PrimLib.is_reduce(dom_op) or _is_atomic_add_available(dom): return None - fused = [] - for a in sibling: - op = a.ops[0] - if a.pattern == PrimLib.REDUCE and not a.stitch_info.stitch_ops and \ - PrimLib.is_reduce(op) and dom_op.inputs[0].shape == op.inputs[0].shape and \ - dom_op.attrs.get("reduce_axis") == op.attrs.get("reduce_axis"): - fused.append(a) - return fused + op = a.ops[0] + return a.pattern == PrimLib.REDUCE and not a.stitch_info.stitch_ops and \ + PrimLib.is_reduce(op) and dom_op.inputs[0].shape == op.inputs[0].shape and \ + dom_op.attrs.get("reduce_axis") == op.attrs.get("reduce_axis") - def _h_opaque(dom, sibling): + def _h_opaque(dom, a): if dom.ops[0].prim not in {"StridedSlice"}: return None - fused = [] - for a in sibling: - if a.ops[0].prim == dom.ops[0].prim and dom.ops[0].output.shape == a.ops[0].output.shape and \ - dom.ops[0].inputs[0].shape == a.ops[0].inputs[0].shape: - fused.append(a) - return fused + return a.ops[0].prim == dom.ops[0].prim and dom.ops[0].output.shape == a.ops[0].output.shape and \ + dom.ops[0].inputs[0].shape == a.ops[0].inputs[0].shape def _fuse_loop(): changed = True