!22647 Support horizontal fusion

Merge pull request !22647 from Gaoxiong/master
i-robot 2021-08-31 11:37:48 +00:00 committed by Gitee
commit 214556ff0c
1 changed file with 90 additions and 3 deletions


@@ -142,7 +142,10 @@ class GraphSplitByPattern:
relations[a] = max(r, relations[a]) if a in relations else r
def _update_pattern():
self.pattern = max(self.pattern, area.pattern, self.in_relations[area])
if area.pattern > self.pattern:
self.pattern = area.pattern
if area in self.in_relations and self.in_relations[area] > self.pattern:
self.pattern = self.in_relations[area]
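The new `_update_pattern` body replaces the single `max(...)` call with explicit guards, so it no longer indexes `self.in_relations[area]` when `area` has no entry there. A minimal standalone sketch of the equivalent guarded update, using plain values instead of the real `Area` objects (the helper name is illustrative only):

```python
# Equivalent guarded update sketched with plain values; dict.get avoids the
# KeyError that direct indexing could raise when the relation entry is missing.
def update_pattern(current, area_pattern, in_relations, area):
    current = max(current, area_pattern)
    return max(current, in_relations.get(area, current))

print(update_pattern(1, 3, {}, "a"))        # 3: no relation entry, no KeyError
print(update_pattern(1, 3, {"a": 5}, "a"))  # 5: the relation pattern dominates
```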
def _fuse_relation(self_relations, new_relations):
for a, r in new_relations.items():
@@ -274,9 +277,8 @@ class GraphSplitByPattern:
"""Set default mode"""
area.mode = self.get_default_mode(area.ops[0])
def limit_area_size(self, dominant, fuse_areas):
def limit_area_size(self, dominant, fuse_areas, limit_size=200):
"""Remove some areas if the size is too large"""
limit_size = 200 # an experience number
area_sizes = map(lambda area: len(area.ops), fuse_areas)
dom_size = len(dominant.ops)
if dom_size + prod_reduce(lambda x, y: x + y, area_sizes) <= limit_size:
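The hardcoded threshold becomes a `limit_size` parameter (defaulting to 200, the previous empirical value), so the horizontal pass below can request a tighter budget of 64. The visible check compares the dominant area's op count plus the candidates' op counts against that budget; a toy illustration with invented sizes (the trimming branch is not shown in this hunk):

```python
# Toy illustration of the budget check in limit_area_size (sizes are invented):
# everything is fused only if the combined op count stays within limit_size.
from functools import reduce

dom_size = 50
area_sizes = [30, 40, 120]
limit_size = 200

fits = dom_size + reduce(lambda x, y: x + y, area_sizes) <= limit_size
print(fits)  # False: 240 > 200, so the method removes some candidate areas
```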
@@ -323,6 +325,53 @@ class GraphSplitByPattern:
changed = changed or do_again
return changed
def hfuse(self, selector):
"""Fuse horizontal areas with same input tensor"""
def _do_fuse(areas):
for i in range(len(areas) - 1):
sibling = []
for a in areas[i + 1:]:
if areas[i].unique_id < a.unique_id:
id1, id2 = areas[i].unique_id, a.unique_id
else:
id1, id2 = a.unique_id, areas[i].unique_id
if not self.reach_tab.reachable(id1, id2):
sibling.append(a)
if sibling:
result = selector(areas[i], sibling)
if result:
fuse_areas = self.limit_area_size(areas[i], result, 64)
if not fuse_areas:
continue
for area in fuse_areas:
areas[i].fuse(area)
self.set_area_map(area.ops, areas[i])
self.areas.remove(area)
return True
return False
changed = False
while True:
for dom in self.areas:
if len(dom.out_relations) > 1 and _do_fuse(list(dom.out_relations.keys())):
changed = True
break
else:
break
inputs, _ = self.graph.deduce_parameters()
while True:
for t in inputs:
areas = []
for op in t.to_ops:
a = self.area_map[op]
if a in self.areas and a not in areas:
areas.append(a)
if len(areas) > 1 and _do_fuse(areas):
changed = True
break
else:
break
return changed
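`hfuse` differs from the vertical `fuse` pass: instead of merging an area with its producers or consumers, it merges independent sibling areas that share a consumer (first loop, over `out_relations`) or a graph input (second loop). The reachability check keeps the merge legal, since fusing two areas where one can reach the other through intermediate areas would create a cycle. A small standalone sketch of that selection rule, using simplified stand-in types (`ToyArea`, `reachable`, `select_siblings` are illustrative names, not part of the real classes):

```python
from dataclasses import dataclass, field

@dataclass
class ToyArea:
    unique_id: int
    output_shape: tuple
    reaches: set = field(default_factory=set)  # ids reachable through dependencies

def reachable(a, b):
    # Siblings are only fusable when neither area can reach the other;
    # otherwise the merged area would depend on itself.
    return b.unique_id in a.reaches or a.unique_id in b.reaches

def select_siblings(dom, candidates):
    # Sketch of the selector contract hfuse expects: return the subset of
    # candidates that is safe and (here) shape-compatible to fuse with dom.
    return [a for a in candidates
            if not reachable(dom, a) and a.output_shape == dom.output_shape]

dom = ToyArea(1, (16, 128))
ok = ToyArea(2, (16, 128))                 # independent, same shape -> selected
dep = ToyArea(3, (16, 128), reaches={1})   # can reach dom -> skipped
print([a.unique_id for a in select_siblings(dom, [ok, dep])])  # [2]
```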
def fuse_recom(self, selector):
"""Fuse recompute area to its user"""
for dominant in [self.recom_area, self.recom_user]:
@@ -334,6 +383,7 @@ class GraphSplitByPattern:
continue
if fuse_areas[0] in [self.recom_area, self.recom_user]:
self.recom_user.fuse(self.recom_area)
self.set_area_map(self.recom_area.ops, self.recom_user)
self.recom_res = True
return True
return False
@@ -847,6 +897,40 @@ class GraphSplitGpu(GraphSplitByPattern):
return [a], True
return None
def _h_broadcast(dom, sibling):
if dom.pattern > PrimLib.BROADCAST:
return None
fused = []
for a in sibling:
if a.pattern <= PrimLib.BROADCAST and dom.ops[0].output.shape == a.ops[0].output.shape:
fused.append(a)
return fused
def _h_reduce(dom, sibling):
if dom.pattern != PrimLib.REDUCE or dom.stitch_info.stitch_ops:
return None
dom_op = dom.ops[0]
if not PrimLib.is_reduce(dom_op) or _is_atomic_add_available(dom):
return None
fused = []
for a in sibling:
op = a.ops[0]
if a.pattern == PrimLib.REDUCE and not a.stitch_info.stitch_ops and \
PrimLib.is_reduce(op) and dom_op.inputs[0].shape == op.inputs[0].shape and \
dom_op.attrs.get("reduce_axis") == op.attrs.get("reduce_axis"):
fused.append(a)
return fused
def _h_opaque(dom, sibling):
if dom.ops[0].prim not in {"StridedSlice"}:
return None
fused = []
for a in sibling:
if a.ops[0].prim == dom.ops[0].prim and dom.ops[0].output.shape == a.ops[0].output.shape and \
dom.ops[0].inputs[0].shape == a.ops[0].inputs[0].shape:
fused.append(a)
return fused
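Each selector encodes one compatibility rule: `_h_broadcast` requires matching output shapes for broadcast-or-simpler areas, `_h_reduce` requires single reduce ops over the same input shape and `reduce_axis` (and no stitch info or atomic-add path), and `_h_opaque` currently only pairs `StridedSlice` areas with identical output and first-input shapes. A toy check of the `_h_reduce` shape/axis rule, with dictionaries standing in for ops (field names are illustrative):

```python
# Toy illustration of the _h_reduce matching rule: two reduce areas are
# horizontal-fusion candidates only when they reduce an input of the same
# shape over the same axes.
def reduce_compatible(dom_op, op):
    return (dom_op["input_shape"] == op["input_shape"]
            and dom_op["reduce_axis"] == op["reduce_axis"])

a = {"input_shape": (32, 1024), "reduce_axis": (1,)}
b = {"input_shape": (32, 1024), "reduce_axis": (1,)}
c = {"input_shape": (32, 1024), "reduce_axis": (0,)}
print(reduce_compatible(a, b), reduce_compatible(a, c))  # True False
```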
def _fuse_loop():
changed = True
while changed:
@@ -864,6 +948,9 @@ class GraphSplitGpu(GraphSplitByPattern):
if enable_stitch_fusion:
changed = self.fuse(_reduce_stitch) or changed
self.fuse(_transpose)
self.hfuse(_h_broadcast)
self.hfuse(_h_reduce)
self.hfuse(_h_opaque)
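The three horizontal passes are appended after the existing `self.fuse(_transpose)` call, so horizontal fusion runs once the earlier passes have done their work. A compressed sketch of that ordering, simplified and assuming the `hfuse` calls sit at the same level as the transpose pass (outside the convergence loop); the function is a stand-in, not the real `_fuse_loop`:

```python
# Simplified sketch of the assumed pass ordering: vertical fusers iterate to a
# fixed point first, then each horizontal selector is applied once.
def fuse_loop(fuse, hfuse, vertical_fusers, horizontal_selectors):
    changed = True
    while changed:
        changed = False
        for f in vertical_fusers:
            changed = fuse(f) or changed
    for sel in horizontal_selectors:
        hfuse(sel)

# e.g. fuse_loop(self.fuse, self.hfuse,
#                [_reshape, _elemwise_depth, _elemwise_width],
#                [_h_broadcast, _h_reduce, _h_opaque])
```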
def _fuse_once(fuse_func):
if fuse_func(_reshape) or fuse_func(_elemwise_depth) or fuse_func(_elemwise_width) or \