forked from mindspore-Ecosystem/mindspore
!22647 Support horizontal fusion
Merge pull request !22647 from Gaoxiong/master
commit 214556ff0c
@@ -142,7 +142,10 @@ class GraphSplitByPattern:
             relations[a] = max(r, relations[a]) if a in relations else r

         def _update_pattern():
-            self.pattern = max(self.pattern, area.pattern, self.in_relations[area])
+            if area.pattern > self.pattern:
+                self.pattern = area.pattern
+            if area in self.in_relations and self.in_relations[area] > self.pattern:
+                self.pattern = self.in_relations[area]

         def _fuse_relation(self_relations, new_relations):
             for a, r in new_relations.items():
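Note: the replaced one-liner indexed self.in_relations[area] unconditionally, which raises KeyError when the fused area is not among this area's input relations; the guarded form only reads the relation when it exists. A minimal self-contained sketch of the same update logic (function and variable names here are illustrative, not the class's API):

    def update_pattern(pattern, area_pattern, in_relations, area):
        # Take the max of the current pattern, the fused area's pattern, and,
        # only if one is recorded, the relation pattern toward the fused area.
        if area_pattern > pattern:
            pattern = area_pattern
        if area in in_relations and in_relations[area] > pattern:
            pattern = in_relations[area]
        return pattern

    assert update_pattern(1, 3, {}, "a") == 3        # no KeyError when "a" is absent
    assert update_pattern(1, 2, {"a": 5}, "a") == 5  # recorded relation dominates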
@@ -274,9 +277,8 @@ class GraphSplitByPattern:
             """Set default mode"""
             area.mode = self.get_default_mode(area.ops[0])

-    def limit_area_size(self, dominant, fuse_areas):
+    def limit_area_size(self, dominant, fuse_areas, limit_size=200):
         """Remove some areas if the size is too large"""
-        limit_size = 200  # an experience number
         area_sizes = map(lambda area: len(area.ops), fuse_areas)
         dom_size = len(dominant.ops)
         if dom_size + prod_reduce(lambda x, y: x + y, area_sizes) <= limit_size:
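Note: hoisting the hard-coded cap into a keyword argument keeps the old behavior for existing callers (default 200) while letting new call sites tighten it; the hfuse pass added below passes 64. Sketch of the two kinds of call sites, as they appear in this diff:

    fuse_areas = self.limit_area_size(dominant, fuse_areas)   # vertical fusion, default cap 200
    fuse_areas = self.limit_area_size(areas[i], result, 64)   # horizontal fusion, tighter cap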
@@ -323,6 +325,53 @@ class GraphSplitByPattern:
             changed = changed or do_again
         return changed

+    def hfuse(self, selector):
+        """Fuse horizontal areas with same input tensor"""
+        def _do_fuse(areas):
+            for i in range(len(areas) - 1):
+                sibling = []
+                for a in areas[i + 1:]:
+                    if areas[i].unique_id < a.unique_id:
+                        id1, id2 = areas[i].unique_id, a.unique_id
+                    else:
+                        id1, id2 = a.unique_id, areas[i].unique_id
+                    if not self.reach_tab.reachable(id1, id2):
+                        sibling.append(a)
+                if sibling:
+                    result = selector(areas[i], sibling)
+                    if result:
+                        fuse_areas = self.limit_area_size(areas[i], result, 64)
+                        if not fuse_areas:
+                            continue
+                        for area in fuse_areas:
+                            areas[i].fuse(area)
+                            self.set_area_map(area.ops, areas[i])
+                            self.areas.remove(area)
+                        return True
+            return False
+
+        changed = False
+        while True:
+            for dom in self.areas:
+                if len(dom.out_relations) > 1 and _do_fuse(list(dom.out_relations.keys())):
+                    changed = True
+                    break
+            else:
+                break
+        inputs, _ = self.graph.deduce_parameters()
+        while True:
+            for t in inputs:
+                areas = []
+                for op in t.to_ops:
+                    a = self.area_map[op]
+                    if a in self.areas and a not in areas:
+                        areas.append(a)
+                if len(areas) > 1 and _do_fuse(areas):
+                    changed = True
+                    break
+            else:
+                break
+        return changed
+
     def fuse_recom(self, selector):
         """Fuse recompute area to its user"""
         for dominant in [self.recom_area, self.recom_user]:
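Note: hfuse runs two sibling searches, first over areas fed by a common producer (dom.out_relations), then over areas that consume the same graph input tensor (graph.deduce_parameters), and only pairs areas that cannot reach each other, since fusing two areas with a path between them would create a cycle in the area graph. A minimal sketch of that cycle-safety test on a standalone successor map (reach_tab's real interface is assumed, not shown):

    def reachable(succ, a, b, seen=None):
        # True if area `a` reaches area `b` through the successor map `succ`.
        seen = set() if seen is None else seen
        if a == b:
            return True
        seen.add(a)
        return any(reachable(succ, n, b, seen) for n in succ.get(a, ()) if n not in seen)

    # p feeds both x and y: x and y are horizontal siblings (no path between
    # them in either direction), so fusing them cannot introduce a cycle.
    succ = {"p": ["x", "y"]}
    assert not reachable(succ, "x", "y") and not reachable(succ, "y", "x")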
@@ -334,6 +383,7 @@ class GraphSplitByPattern:
                     continue
                 if fuse_areas[0] in [self.recom_area, self.recom_user]:
                     self.recom_user.fuse(self.recom_area)
+                    self.set_area_map(self.recom_area.ops, self.recom_user)
                     self.recom_res = True
                     return True
         return False
@@ -847,6 +897,40 @@ class GraphSplitGpu(GraphSplitByPattern):
                 return [a], True
             return None

+        def _h_broadcast(dom, sibling):
+            if dom.pattern > PrimLib.BROADCAST:
+                return None
+            fused = []
+            for a in sibling:
+                if a.pattern <= PrimLib.BROADCAST and dom.ops[0].output.shape == a.ops[0].output.shape:
+                    fused.append(a)
+            return fused
+
+        def _h_reduce(dom, sibling):
+            if dom.pattern != PrimLib.REDUCE or dom.stitch_info.stitch_ops:
+                return None
+            dom_op = dom.ops[0]
+            if not PrimLib.is_reduce(dom_op) or _is_atomic_add_available(dom):
+                return None
+            fused = []
+            for a in sibling:
+                op = a.ops[0]
+                if a.pattern == PrimLib.REDUCE and not a.stitch_info.stitch_ops and \
+                        PrimLib.is_reduce(op) and dom_op.inputs[0].shape == op.inputs[0].shape and \
+                        dom_op.attrs.get("reduce_axis") == op.attrs.get("reduce_axis"):
+                    fused.append(a)
+            return fused
+
+        def _h_opaque(dom, sibling):
+            if dom.ops[0].prim not in {"StridedSlice"}:
+                return None
+            fused = []
+            for a in sibling:
+                if a.ops[0].prim == dom.ops[0].prim and dom.ops[0].output.shape == a.ops[0].output.shape and \
+                        dom.ops[0].inputs[0].shape == a.ops[0].inputs[0].shape:
+                    fused.append(a)
+            return fused
+
         def _fuse_loop():
             changed = True
             while changed:
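Note: each selector encodes when siblings are compatible for horizontal fusion: broadcast areas need equal output shapes, reduce areas need equal input shapes and identical reduce_axis (and no atomic-add or stitch involvement), and the opaque case is limited to StridedSlice with equal input and output shapes. A toy version of the _h_reduce compatibility test, with dicts standing in for real ops:

    def h_reduce_compatible(dom_op, op):
        # Mirrors _h_reduce's condition: same-shaped input reduced over the same axes.
        return (dom_op["in_shape"] == op["in_shape"]
                and dom_op["reduce_axis"] == op["reduce_axis"])

    a = {"in_shape": (32, 128), "reduce_axis": (1,)}
    b = {"in_shape": (32, 128), "reduce_axis": (1,)}
    c = {"in_shape": (32, 128), "reduce_axis": (0,)}
    assert h_reduce_compatible(a, b) and not h_reduce_compatible(a, c)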
@@ -864,6 +948,9 @@ class GraphSplitGpu(GraphSplitByPattern):
                 if enable_stitch_fusion:
                     changed = self.fuse(_reduce_stitch) or changed
             self.fuse(_transpose)
+            self.hfuse(_h_broadcast)
+            self.hfuse(_h_reduce)
+            self.hfuse(_h_opaque)

         def _fuse_once(fuse_func):
             if fuse_func(_reshape) or fuse_func(_elemwise_depth) or fuse_func(_elemwise_width) or \
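Note: judging by the indentation, the three hfuse passes sit after the `while changed` loop in _fuse_loop, so horizontal fusion runs once, over the areas left after vertical fusion has converged; their return values are not fed back into the loop's changed flag.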