!15658 Reduce recursion overhead of split model

From: @gaoxiong1
Reviewed-by: @anyrenwei,@dylangeng
Signed-off-by: @dylangeng
This commit is contained in:
mindspore-ci-bot 2021-04-26 16:39:04 +08:00 committed by Gitee
commit d6f58cb765
2 changed files with 23 additions and 32 deletions

View File

@@ -30,7 +30,7 @@ class GraphSplitByPattern:
self.stitch_ops = set() self.stitch_ops = set()
self.stitch_atomic_ops = set() self.stitch_atomic_ops = set()
def __init__(self, init_op, is_output): def __init__(self, init_op, is_output, unique_id):
self.pattern = PrimLib.iter_type(init_op) self.pattern = PrimLib.iter_type(init_op)
self.ops = [init_op] self.ops = [init_op]
self.in_relations = dict() # {area1: relation1, area2: relation2, ...} self.in_relations = dict() # {area1: relation1, area2: relation2, ...}
@@ -48,6 +48,8 @@ class GraphSplitByPattern:
else: else:
_gather_reduce_exclude(to) _gather_reduce_exclude(to)
_gather_reduce_exclude(init_op) _gather_reduce_exclude(init_op)
self.reach_map = dict()
self.unique_id = unique_id
def __str__(self): def __str__(self):
return '<' + '-'.join([op.output.name for op in self.ops]) + '>' return '<' + '-'.join([op.output.name for op in self.ops]) + '>'
@@ -122,13 +124,20 @@ class GraphSplitByPattern:
if area.output_excluded: if area.output_excluded:
self.output_excluded.update(area.output_excluded) self.output_excluded.update(area.output_excluded)
self.update_stitch_info(area.stitch_info) self.update_stitch_info(area.stitch_info)
for to, reach in area.reach_map.items():
if reach and not self.reach_map.get(to, False):
self.reach_map[to] = True
def check_acyclic(self, to): def check_acyclic(self, to):
"""Check circle. It returns false if circle exists""" """Check circle. It returns false if circle exists"""
def _reached(area, to): def _reached(area, to):
if to.unique_id in area.reach_map:
return area.reach_map[to.unique_id]
for out, _ in area.out_relations.items(): for out, _ in area.out_relations.items():
if out == to or _reached(out, to): if out == to or _reached(out, to):
area.reach_map[to.unique_id] = True
return True return True
area.reach_map[to.unique_id] = False
return False return False
for out, _ in self.out_relations.items(): for out, _ in self.out_relations.items():
if out != to and _reached(out, to): if out != to and _reached(out, to):
@@ -151,9 +160,11 @@ class GraphSplitByPattern:
self.flags = flags self.flags = flags
area_map = {} area_map = {}
_, outputs = graph.deduce_parameters() _, outputs = graph.deduce_parameters()
idx = 0
for op in graph.ops: for op in graph.ops:
is_output = op.output in outputs is_output = op.output in outputs
a = self.Area(op, is_output) a = self.Area(op, is_output, idx)
idx += 1
self.set_default_mode(a) self.set_default_mode(a)
self.areas.append(a) self.areas.append(a)
area_map[op] = a area_map[op] = a
@@ -249,7 +260,7 @@ class GraphSplitByPattern:
break break
if out_reshape_ops: if out_reshape_ops:
for op in out_reshape_ops: for op in out_reshape_ops:
a = self.Area(op, False) a = self.Area(op, False, -1)
self.set_default_mode(a) self.set_default_mode(a)
new_areas.append(a) new_areas.append(a)
area.ops = remain_ops area.ops = remain_ops

View File

@@ -448,45 +448,26 @@ class Graph:
class GraphVisitor: class GraphVisitor:
"""Graph visitor""" """Graph visitor"""
def __init__(self, forward=True, once_mode=True): def __init__(self, forward=True):
self.forward = forward self.forward = forward
self.once_mode = once_mode
if self.once_mode:
self.visited = set()
def visit_graph(self, graph): def visit_graph(self, graph):
"""Visit graph""" """Visit graph"""
inputs, outputs = graph.deduce_parameters()
if self.forward: if self.forward:
for tensor in inputs: for op in graph.ops:
for op in tensor.to_ops: self.visit(op)
self.visit(op)
else: else:
for tensor in outputs: for i in range(len(graph.ops)-1, -1, -1):
if not tensor.to_ops: self.visit(graph.ops[i])
self.visit(tensor.op)
def visit(self, op):
"""Visit op"""
next_ops = op.output.to_ops if self.forward else [
t.op for t in op.inputs if t.op is not None]
if self.once_mode:
self.visited.add(op)
for n in next_ops:
if n not in self.visited:
self.visit(n)
else:
for n in next_ops:
self.visit(n)
class AlignShape(GraphVisitor): class AlignShape(GraphVisitor):
"""Align shape""" """Align shape"""
def __init__(self): def __init__(self):
super().__init__(once_mode=False) super().__init__()
def visit(self, op): def visit(self, op):
"""Visit op node"""
prim = PrimLib.get_prim(op) prim = PrimLib.get_prim(op)
if prim.iter_type in (PrimLib.ELEMWISE, PrimLib.BROADCAST, PrimLib.REDUCE): if prim.iter_type in (PrimLib.ELEMWISE, PrimLib.BROADCAST, PrimLib.REDUCE):
out_dim = len(op.output.shape) out_dim = len(op.output.shape)
@@ -496,8 +477,6 @@ class AlignShape(GraphVisitor):
align_dim = len(t.shape) align_dim = len(t.shape)
if align_dim > out_dim: if align_dim > out_dim:
op.output.shape = [1] * (align_dim - out_dim) + op.output.shape op.output.shape = [1] * (align_dim - out_dim) + op.output.shape
super().visit(op)
class AddControlBuddy(GraphVisitor): class AddControlBuddy(GraphVisitor):
"""Add control buddy""" """Add control buddy"""
@@ -507,6 +486,7 @@ class AddControlBuddy(GraphVisitor):
self.buddies = {} # {op : [ctrl_op]} self.buddies = {} # {op : [ctrl_op]}
def visit(self, op): def visit(self, op):
"""Visit op node"""
if op.prim == "MakeTuple": if op.prim == "MakeTuple":
assert len(op.output.to_ops) == 1 assert len(op.output.to_ops) == 1
owner = op.output.to_ops[0] owner = op.output.to_ops[0]
@@ -517,9 +497,9 @@ class AddControlBuddy(GraphVisitor):
if op in self.buddies: if op in self.buddies:
ops = self.buddies.pop(op) ops = self.buddies.pop(op)
self.buddies[owner].extend(ops) self.buddies[owner].extend(ops)
super().visit(op)
def visit_graph(self, graph): def visit_graph(self, graph):
"""Visit graph nodes"""
super().visit_graph(graph) super().visit_graph(graph)
for owner in self.buddies: for owner in self.buddies:
for op in self.buddies[owner]: for op in self.buddies[owner]: