!35220 isolate reshape can only fuse elemwise/broadcast pattern in forward
Merge pull request !35220 from DeshiChen/0530_isolate_reshape
This commit is contained in:
commit
dfca012064
|
@ -106,7 +106,8 @@ class CommonPattern:
|
||||||
if a.mode == GraphSplitByPattern.Area.MODE_COMPOSITE and dom.check_acyclic(a):
|
if a.mode == GraphSplitByPattern.Area.MODE_COMPOSITE and dom.check_acyclic(a):
|
||||||
return [a], False
|
return [a], False
|
||||||
for a, _ in dom.in_relations.items():
|
for a, _ in dom.in_relations.items():
|
||||||
if a.mode == GraphSplitByPattern.Area.MODE_COMPOSITE and a.check_acyclic(dom):
|
if a.mode == GraphSplitByPattern.Area.MODE_COMPOSITE and a.pattern <= PrimLib.BROADCAST and \
|
||||||
|
a.check_acyclic(dom):
|
||||||
return [a], True
|
return [a], True
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
@ -214,6 +215,7 @@ class ReshapeElimChecker:
|
||||||
break
|
break
|
||||||
des_idx, des_prod = des_idx - 1, prod
|
des_idx, des_prod = des_idx - 1, prod
|
||||||
return out_remap
|
return out_remap
|
||||||
|
|
||||||
def _remap_check(op, remap, iter_type):
|
def _remap_check(op, remap, iter_type):
|
||||||
if iter_type not in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
|
if iter_type not in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
|
||||||
return False
|
return False
|
||||||
|
@ -222,14 +224,16 @@ class ReshapeElimChecker:
|
||||||
if -i <= len(t.shape) and t.shape[i] != op.output.shape[i]:
|
if -i <= len(t.shape) and t.shape[i] != op.output.shape[i]:
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def push_stack(op, remap):
|
def push_stack(op, remap):
|
||||||
stack.append((op, remap))
|
stack.append((op, remap))
|
||||||
visited.add(op)
|
visited.add(op)
|
||||||
|
|
||||||
def _visit_fwd(op, remap):
|
def _visit_fwd(op, remap):
|
||||||
for t in op.inputs:
|
for t in op.inputs:
|
||||||
if t.op is None:
|
if t.op is None:
|
||||||
_visit_bwd(t, remap)
|
_visit_bwd(t, remap)
|
||||||
elif tensor_size(t) > 1 and t.op not in visited: # all broadcast
|
elif tensor_size(t) > 1 and t.op not in visited: # all broadcast
|
||||||
iter_type = PrimLib.iter_type(t.op)
|
iter_type = PrimLib.iter_type(t.op)
|
||||||
if iter_type == PrimLib.RESHAPE:
|
if iter_type == PrimLib.RESHAPE:
|
||||||
new_remap = _propagate(remap, t.shape, t.op.inputs[0].shape)
|
new_remap = _propagate(remap, t.shape, t.op.inputs[0].shape)
|
||||||
|
@ -238,11 +242,12 @@ class ReshapeElimChecker:
|
||||||
push_stack(t.op, remap)
|
push_stack(t.op, remap)
|
||||||
else:
|
else:
|
||||||
exc_ops.add(t.op)
|
exc_ops.add(t.op)
|
||||||
|
|
||||||
def _visit_bwd(t, remap):
|
def _visit_bwd(t, remap):
|
||||||
for op in t.to_ops:
|
for op in t.to_ops:
|
||||||
if op not in visited:
|
if op not in visited:
|
||||||
iter_type = PrimLib.iter_type(op)
|
iter_type = PrimLib.iter_type(op)
|
||||||
if iter_type == PrimLib.REDUCE and tensor_size(op.output) == 1: # all reduce
|
if iter_type == PrimLib.REDUCE and tensor_size(op.output) == 1: # all reduce
|
||||||
continue
|
continue
|
||||||
if iter_type == PrimLib.RESHAPE:
|
if iter_type == PrimLib.RESHAPE:
|
||||||
new_remap = _propagate(remap, t.shape, op.output.shape)
|
new_remap = _propagate(remap, t.shape, op.output.shape)
|
||||||
|
@ -307,7 +312,7 @@ class ReduceOutFuseChecker:
|
||||||
def commit(self, res):
|
def commit(self, res):
|
||||||
""" commit fuse result """
|
""" commit fuse result """
|
||||||
del res
|
del res
|
||||||
return self.output_excluded # I'm not static
|
return self.output_excluded # I'm not static
|
||||||
|
|
||||||
|
|
||||||
class GraphSplitByPattern:
|
class GraphSplitByPattern:
|
||||||
|
@ -432,6 +437,7 @@ class GraphSplitByPattern:
|
||||||
return False
|
return False
|
||||||
res.append(r)
|
res.append(r)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _commit(a, res):
|
def _commit(a, res):
|
||||||
for i, checker in enumerate(a.checkers):
|
for i, checker in enumerate(a.checkers):
|
||||||
checker.commit(res[i])
|
checker.commit(res[i])
|
||||||
|
@ -508,7 +514,6 @@ class GraphSplitByPattern:
|
||||||
"""Get dom op"""
|
"""Get dom op"""
|
||||||
return self.ops[0]
|
return self.ops[0]
|
||||||
|
|
||||||
|
|
||||||
class RecomputeArea(Area):
|
class RecomputeArea(Area):
|
||||||
"""RecomputeArea"""
|
"""RecomputeArea"""
|
||||||
|
|
||||||
|
@ -702,7 +707,7 @@ class GraphSplitByPattern:
|
||||||
dom = areas[i]
|
dom = areas[i]
|
||||||
for a in areas[i + 1:]:
|
for a in areas[i + 1:]:
|
||||||
if dom.check_acyclic(a) and a.check_acyclic(dom) and \
|
if dom.check_acyclic(a) and a.check_acyclic(dom) and \
|
||||||
selector(dom, a) and self.limit_area_size(dom, [a], 64) and dom.fuse_confirm(a):
|
selector(dom, a) and self.limit_area_size(dom, [a], 64) and dom.fuse_confirm(a):
|
||||||
dom.fuse(a)
|
dom.fuse(a)
|
||||||
self.set_area_map(a.ops, dom)
|
self.set_area_map(a.ops, dom)
|
||||||
self.areas.remove(a)
|
self.areas.remove(a)
|
||||||
|
@ -1167,16 +1172,16 @@ class GraphSplitGpu(GraphSplitByPattern):
|
||||||
"""Fuse rule for injective """
|
"""Fuse rule for injective """
|
||||||
injective_ops = {"Transpose", "StridedSlice"}
|
injective_ops = {"Transpose", "StridedSlice"}
|
||||||
if dom.dom_op().prim not in injective_ops:
|
if dom.dom_op().prim not in injective_ops:
|
||||||
return None
|
return []
|
||||||
to_ops = dom.dom_op().output.to_ops
|
to_ops = dom.dom_op().output.to_ops
|
||||||
if dom.is_output or len(to_ops) != 1 or len(dom.out_relations) != 1:
|
if dom.is_output or len(to_ops) != 1 or len(dom.out_relations) != 1:
|
||||||
return None
|
return []
|
||||||
to_area = list(dom.out_relations.keys())[0]
|
to_area = list(dom.out_relations.keys())[0]
|
||||||
if (to_area.pattern > PrimLib.REDUCE and to_area.dom_op().prim not in injective_ops) or \
|
if (to_area.pattern > PrimLib.REDUCE and to_area.dom_op().prim not in injective_ops) or \
|
||||||
to_ops[0] not in to_area.ops:
|
to_ops[0] not in to_area.ops:
|
||||||
return None
|
return []
|
||||||
if len(to_area.ops) > self.TRANSPOSE_FUSE_DEPTH:
|
if len(to_area.ops) > self.TRANSPOSE_FUSE_DEPTH:
|
||||||
return None
|
return []
|
||||||
return [to_area], False
|
return [to_area], False
|
||||||
|
|
||||||
def _h_broadcast(dom, a):
|
def _h_broadcast(dom, a):
|
||||||
|
|
Loading…
Reference in New Issue