From 0586709409c2da2182a8a30bef1823b81d7364c7 Mon Sep 17 00:00:00 2001
From: Gaoxiong
Date: Tue, 31 Aug 2021 09:52:42 +0800
Subject: [PATCH] support horizontal fusion

---
 .../graph_kernel/model/graph_split.py        | 93 ++++++++++++++++++-
 1 file changed, 90 insertions(+), 3 deletions(-)

diff --git a/mindspore/_extends/graph_kernel/model/graph_split.py b/mindspore/_extends/graph_kernel/model/graph_split.py
index 6bcb74baa03..06d7636149f 100644
--- a/mindspore/_extends/graph_kernel/model/graph_split.py
+++ b/mindspore/_extends/graph_kernel/model/graph_split.py
@@ -142,7 +142,10 @@ class GraphSplitByPattern:
             relations[a] = max(r, relations[a]) if a in relations else r
 
         def _update_pattern():
-            self.pattern = max(self.pattern, area.pattern, self.in_relations[area])
+            if area.pattern > self.pattern:
+                self.pattern = area.pattern
+            if area in self.in_relations and self.in_relations[area] > self.pattern:
+                self.pattern = self.in_relations[area]
 
         def _fuse_relation(self_relations, new_relations):
             for a, r in new_relations.items():
@@ -274,9 +277,8 @@ class GraphSplitByPattern:
         """Set default mode"""
         area.mode = self.get_default_mode(area.ops[0])
 
-    def limit_area_size(self, dominant, fuse_areas):
+    def limit_area_size(self, dominant, fuse_areas, limit_size=200):
         """Remove some areas if the size is too large"""
-        limit_size = 200  # an experience number
         area_sizes = map(lambda area: len(area.ops), fuse_areas)
         dom_size = len(dominant.ops)
         if dom_size + prod_reduce(lambda x, y: x + y, area_sizes) <= limit_size:
@@ -323,6 +325,53 @@ class GraphSplitByPattern:
             changed = changed or do_again
         return changed
 
+    def hfuse(self, selector):
+        """Fuse horizontal areas with same input tensor"""
+        def _do_fuse(areas):
+            for i in range(len(areas) - 1):
+                sibling = []
+                for a in areas[i + 1:]:
+                    if areas[i].unique_id < a.unique_id:
+                        id1, id2 = areas[i].unique_id, a.unique_id
+                    else:
+                        id1, id2 = a.unique_id, areas[i].unique_id
+                    if not self.reach_tab.reachable(id1, id2):
+                        sibling.append(a)
+                if sibling:
+                    result = selector(areas[i], sibling)
+                    if result:
+                        fuse_areas = self.limit_area_size(areas[i], result, 64)
+                        if not fuse_areas:
+                            continue
+                        for area in fuse_areas:
+                            areas[i].fuse(area)
+                            self.set_area_map(area.ops, areas[i])
+                            self.areas.remove(area)
+                        return True
+            return False
+        changed = False
+        while True:
+            for dom in self.areas:
+                if len(dom.out_relations) > 1 and _do_fuse(list(dom.out_relations.keys())):
+                    changed = True
+                    break
+            else:
+                break
+        inputs, _ = self.graph.deduce_parameters()
+        while True:
+            for t in inputs:
+                areas = []
+                for op in t.to_ops:
+                    a = self.area_map[op]
+                    if a in self.areas and a not in areas:
+                        areas.append(a)
+                if len(areas) > 1 and _do_fuse(areas):
+                    changed = True
+                    break
+            else:
+                break
+        return changed
+
     def fuse_recom(self, selector):
         """Fuse recompute area to its user"""
         for dominant in [self.recom_area, self.recom_user]:
@@ -334,6 +383,7 @@
                 continue
             if fuse_areas[0] in [self.recom_area, self.recom_user]:
                 self.recom_user.fuse(self.recom_area)
+                self.set_area_map(self.recom_area.ops, self.recom_user)
                 self.recom_res = True
                 return True
         return False
@@ -847,6 +897,40 @@ class GraphSplitGpu(GraphSplitByPattern):
                 return [a], True
             return None
 
+        def _h_broadcast(dom, sibling):
+            if dom.pattern > PrimLib.BROADCAST:
+                return None
+            fused = []
+            for a in sibling:
+                if a.pattern <= PrimLib.BROADCAST and dom.ops[0].output.shape == a.ops[0].output.shape:
+                    fused.append(a)
+            return fused
+
+        def _h_reduce(dom, sibling):
+            if dom.pattern != PrimLib.REDUCE or dom.stitch_info.stitch_ops:
+                return None
+            dom_op = dom.ops[0]
+            if not PrimLib.is_reduce(dom_op) or _is_atomic_add_available(dom):
+                return None
+            fused = []
+            for a in sibling:
+                op = a.ops[0]
+                if a.pattern == PrimLib.REDUCE and not a.stitch_info.stitch_ops and \
+                        PrimLib.is_reduce(op) and dom_op.inputs[0].shape == op.inputs[0].shape and \
+                        dom_op.attrs.get("reduce_axis") == op.attrs.get("reduce_axis"):
+                    fused.append(a)
+            return fused
+
+        def _h_opaque(dom, sibling):
+            if dom.ops[0].prim not in {"StridedSlice"}:
+                return None
+            fused = []
+            for a in sibling:
+                if a.ops[0].prim == dom.ops[0].prim and dom.ops[0].output.shape == a.ops[0].output.shape and \
+                        dom.ops[0].inputs[0].shape == a.ops[0].inputs[0].shape:
+                    fused.append(a)
+            return fused
+
         def _fuse_loop():
             changed = True
             while changed:
@@ -864,6 +948,9 @@ class GraphSplitGpu(GraphSplitByPattern):
                 if enable_stitch_fusion:
                     changed = self.fuse(_reduce_stitch) or changed
             self.fuse(_transpose)
+            self.hfuse(_h_broadcast)
+            self.hfuse(_h_reduce)
+            self.hfuse(_h_opaque)
 
         def _fuse_once(fuse_func):
             if fuse_func(_reshape) or fuse_func(_elemwise_depth) or fuse_func(_elemwise_width) or \
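
Note on the reachability check in hfuse: two sibling areas are fused only when
neither can reach the other in the area graph (the reach_tab.reachable(id1, id2)
test, with the ids put in canonical order first), because fusing two areas that
are connected by a path would merge the endpoints of that path into one node
and create a cycle. Below is a minimal, self-contained sketch of that sibling
test under stated assumptions: Area, ReachTable, and find_siblings are
illustrative stand-ins rather than the MindSpore classes, and treating
reachable() as a "path exists in either direction" query is an assumption.

from itertools import count


class ReachTable:
    """Records direct edges between area ids and answers path queries."""

    def __init__(self):
        self.succ = {}  # area id -> set of direct successor ids

    def add_edge(self, src, dst):
        self.succ.setdefault(src, set()).add(dst)

    def _walk(self, start, target):
        # Iterative DFS: is there a directed path from start to target?
        stack, seen = [start], set()
        while stack:
            n = stack.pop()
            if n == target:
                return True
            if n in seen:
                continue
            seen.add(n)
            stack.extend(self.succ.get(n, ()))
        return False

    def reachable(self, id1, id2):
        # Assumption: "reachable" means a path exists in either direction.
        return self._walk(id1, id2) or self._walk(id2, id1)


class Area:
    """Stand-in for a fusion area: just a unique id and a name."""
    _ids = count()

    def __init__(self, name):
        self.unique_id = next(Area._ids)
        self.name = name


def find_siblings(dom, candidates, reach_tab):
    """Mirror of the sibling collection in _do_fuse: keep only candidates
    with no path to or from dom, so fusing them cannot create a cycle."""
    return [a for a in candidates
            if not reach_tab.reachable(dom.unique_id, a.unique_id)]


if __name__ == "__main__":
    reach = ReachTable()
    a = Area("reduce_a")
    b = Area("reduce_b")        # independent of a: a true sibling
    c = Area("consumer_of_a")
    reach.add_edge(a.unique_id, c.unique_id)  # c consumes a's output

    print([s.name for s in find_siblings(a, [b, c], reach)])
    # prints ['reduce_b']: c is excluded because a can reach it

The surviving candidates are then filtered by a shape-aware selector such as
_h_broadcast or _h_reduce, and the fused group is size-capped through
limit_area_size with the tighter limit of 64 used for horizontal fusion.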