!35220 isolate reshape can only fuse elemwise/broadcast patterns in the forward direction

Merge pull request !35220 from DeshiChen/0530_isolate_reshape
i-robot 2022-06-01 02:12:04 +00:00 committed by Gitee
commit dfca012064
1 changed file with 15 additions and 10 deletions


@@ -106,7 +106,8 @@ class CommonPattern:
             if a.mode == GraphSplitByPattern.Area.MODE_COMPOSITE and dom.check_acyclic(a):
                 return [a], False
         for a, _ in dom.in_relations.items():
-            if a.mode == GraphSplitByPattern.Area.MODE_COMPOSITE and a.check_acyclic(dom):
+            if a.mode == GraphSplitByPattern.Area.MODE_COMPOSITE and a.pattern <= PrimLib.BROADCAST and \
+                    a.check_acyclic(dom):
                 return [a], True
         return []
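
The added `a.pattern <= PrimLib.BROADCAST` guard is what the commit title describes: an input area is now only fused forward when its pattern is elemwise or broadcast. A minimal sketch of why the comparison expresses that, assuming PrimLib orders iteration types by complexity (the constants below are illustrative, not the real values in MindSpore's model):

# Hedged sketch: illustrative PrimLib-style ordering of iteration types.
class PrimLib:
    UNKNOWN = 0
    RESHAPE = 1
    ELEMWISE = 2
    BROADCAST = 3
    REDUCE = 4

def can_fuse_forward(area_pattern):
    # Only areas no more complex than BROADCAST pass the new guard, so a
    # REDUCE-pattern predecessor can no longer be pulled into the area.
    return area_pattern <= PrimLib.BROADCAST

assert can_fuse_forward(PrimLib.ELEMWISE)
assert can_fuse_forward(PrimLib.BROADCAST)
assert not can_fuse_forward(PrimLib.REDUCE)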
@@ -214,6 +215,7 @@ class ReshapeElimChecker:
                     break
                 des_idx, des_prod = des_idx - 1, prod
             return out_remap
+
         def _remap_check(op, remap, iter_type):
             if iter_type not in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
                 return False
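
These are the closing lines of the remap propagation across a reshape: it walks the destination shape from the last axis, growing `des_idx`/`des_prod` until the accumulated element count matches the span of the remapped source axes. A minimal sketch of that bookkeeping, with hypothetical names and without the real `_propagate`'s full failure handling:

def propagate(remap, src_shape, des_shape):
    # For each tracked source axis (a negative index), find the destination
    # axis whose trailing span covers exactly the same number of elements.
    out_remap = set()
    for i in remap:
        src_size = 1
        for dim in src_shape[i:]:
            src_size *= dim                # elements from axis i to the end
        des_idx, des_prod = -1, des_shape[-1]
        while des_prod < src_size and -des_idx < len(des_shape):
            des_idx, des_prod = des_idx - 1, des_prod * des_shape[des_idx - 1]
        if des_prod != src_size:
            return None                    # the reshape splits this axis boundary
        out_remap.add(des_idx)
    return out_remap

# (1, 4, 6) -> (4, 6): the trailing span of axis -2 covers 24 elements either way.
assert propagate({-1}, (1, 4, 6), (4, 6)) == {-1}
assert propagate({-2}, (1, 4, 6), (4, 6)) == {-2}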
@@ -222,14 +224,16 @@
                 if -i <= len(t.shape) and t.shape[i] != op.output.shape[i]:
                     return False
             return True
+
         def push_stack(op, remap):
             stack.append((op, remap))
             visited.add(op)
+
         def _visit_fwd(op, remap):
             for t in op.inputs:
                 if t.op is None:
                     _visit_bwd(t, remap)
                 elif tensor_size(t) > 1 and t.op not in visited:  # all broadcast
                     iter_type = PrimLib.iter_type(t.op)
                     if iter_type == PrimLib.RESHAPE:
                         new_remap = _propagate(remap, t.shape, t.op.inputs[0].shape)
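
`_remap_check` accepts only elemwise/broadcast ops whose remapped axes keep their extent from inputs to output; shorter inputs are implicitly broadcast on the axes they lack. A self-contained sketch of that negative-index check (the helper name and shapes-as-lists are illustrative):

def remap_ok(input_shapes, output_shape, remap):
    # remap holds negative axis indices counted from the last dimension.
    for i in remap:
        for shape in input_shapes:
            if -i <= len(shape) and shape[i] != output_shape[i]:
                return False       # the axis would change extent across the op
    return True

assert remap_ok([[4, 8], [4, 8]], [4, 8], [-1, -2])   # pure elemwise: fine
assert not remap_ok([[4, 2]], [4, 8], [-1])           # mismatch on a tracked axis: rejected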
@@ -238,11 +242,12 @@
                         push_stack(t.op, remap)
                     else:
                         exc_ops.add(t.op)
+
         def _visit_bwd(t, remap):
             for op in t.to_ops:
                 if op not in visited:
                     iter_type = PrimLib.iter_type(op)
                     if iter_type == PrimLib.REDUCE and tensor_size(op.output) == 1:  # all reduce
                         continue
                     if iter_type == PrimLib.RESHAPE:
                         new_remap = _propagate(remap, t.shape, op.output.shape)
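
`_visit_fwd` and `_visit_bwd` flood outward from the reshape over an explicit work stack: each reshape crossed re-derives the tracked axes via `_propagate`, and any op that cannot keep the remap valid lands in `exc_ops`, excluding it from the elimination. A generic sketch of that traversal, with hypothetical callbacks standing in for the real neighbor and propagation logic:

def flood(start_op, start_remap, neighbors, remap_through):
    # Depth-first walk; remap_through returns None when the remap breaks.
    stack, visited, excluded = [(start_op, start_remap)], {start_op}, set()
    while stack:
        op, remap = stack.pop()
        for nxt in neighbors(op):
            if nxt in visited:
                continue
            new_remap = remap_through(nxt, remap)
            if new_remap is None:
                excluded.add(nxt)          # remap broken: stop at this op
            else:
                visited.add(nxt)
                stack.append((nxt, new_remap))
    return visited, excluded

# Chain a -> b -> c where b breaks the remap: c is never reached.
nbrs = {"a": ["b"], "b": ["c"]}
vis, exc = flood("a", {-1}, lambda op: nbrs.get(op, []),
                 lambda op, remap: None if op == "b" else remap)
assert vis == {"a"} and exc == {"b"}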
@@ -307,7 +312,7 @@ class ReduceOutFuseChecker:
     def commit(self, res):
         """ commit fuse result """
         del res
         return self.output_excluded  # I'm not static


 class GraphSplitByPattern:
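
For context: `del res` explicitly discards the unused parameter, and the `# I'm not static` comment records that the method reads `self`, so a linter should not demand a staticmethod; both exist because every checker must expose the same `commit` signature. A sketch of the protocol as it appears from this diff (the class below is illustrative, not the real checker):

class ExampleChecker:
    """Minimal checker shape implied by the diff: check() then commit()."""
    def __init__(self):
        self.output_excluded = set()

    def check(self, area):
        return True            # hypothetical: accept every candidate

    def commit(self, res):
        del res                # result unused, but the signature must match
        return self.output_excluded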
@@ -432,6 +437,7 @@
                     return False
                 res.append(r)
             return True
+
         def _commit(a, res):
             for i, checker in enumerate(a.checkers):
                 checker.commit(res[i])
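
`_commit` is the second half of a two-phase confirmation: first every checker in `a.checkers` is queried and its result buffered into `res`, aborting on the first refusal; only if all agree does `_commit` replay the buffered results, so no checker mutates state for a rejected fuse. A minimal sketch of that pattern (the `check` signature is a hypothetical stand-in):

def fuse_confirm(area, other):
    res = []
    for checker in area.checkers:
        r = checker.check(other)       # phase 1: collect verdicts
        if r is False:
            return False               # any refusal aborts before side effects
        res.append(r)
    for checker, r in zip(area.checkers, res):
        checker.commit(r)              # phase 2: commit all at once
    return True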
@@ -508,7 +514,6 @@
             """Get dom op"""
             return self.ops[0]
-
     class RecomputeArea(Area):
         """RecomputeArea"""
@@ -702,7 +707,7 @@
             dom = areas[i]
             for a in areas[i + 1:]:
                 if dom.check_acyclic(a) and a.check_acyclic(dom) and \
                         selector(dom, a) and self.limit_area_size(dom, [a], 64) and dom.fuse_confirm(a):
                     dom.fuse(a)
                     self.set_area_map(a.ops, dom)
                     self.areas.remove(a)
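
Both `dom.check_acyclic(a)` and `a.check_acyclic(dom)` must hold because merging two areas is only safe when no indirect path connects them; otherwise the merged node would sit on a cycle in the area graph and could not be scheduled. An illustrative check on a tiny area DAG (the DFS below is a stand-in, not the real implementation):

def has_indirect_path(graph, src, dst):
    # DFS that ignores the direct src -> dst edge.
    stack, seen = [src], set()
    while stack:
        node = stack.pop()
        for nxt in graph.get(node, ()):
            if node == src and nxt == dst:
                continue               # skip the direct edge
            if nxt == dst:
                return True
            if nxt not in seen:
                seen.add(nxt)
                stack.append(nxt)
    return False

# A -> B directly and A -> C -> B indirectly: fusing A with B would trap C
# between the two halves of the merged area, i.e. create a cycle.
graph = {"A": ["B", "C"], "C": ["B"]}
assert has_indirect_path(graph, "A", "B")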
@@ -1167,16 +1172,16 @@ class GraphSplitGpu(GraphSplitByPattern):
             """Fuse rule for injective """
             injective_ops = {"Transpose", "StridedSlice"}
             if dom.dom_op().prim not in injective_ops:
-                return None
+                return []
             to_ops = dom.dom_op().output.to_ops
             if dom.is_output or len(to_ops) != 1 or len(dom.out_relations) != 1:
-                return None
+                return []
             to_area = list(dom.out_relations.keys())[0]
             if (to_area.pattern > PrimLib.REDUCE and to_area.dom_op().prim not in injective_ops) or \
                     to_ops[0] not in to_area.ops:
-                return None
+                return []
             if len(to_area.ops) > self.TRANSPOSE_FUSE_DEPTH:
-                return None
+                return []
             return [to_area], False

         def _h_broadcast(dom, a):
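
The four `return None` → `return []` changes align `_injective` with the other fuse selectors (compare the `return []` in CommonPattern above): every "no fusion" outcome is now the same falsy empty list, so a caller needs no None special case. A small sketch of a caller under that assumption:

def apply_selector(selector, dom):
    result = selector(dom)     # [] (or None) means "skip"; otherwise (areas, flag)
    if not result:             # one uniform falsy test, no `is None` special case
        return None
    areas, flag = result
    return areas, flag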