forked from mindspore-Ecosystem/mindspore
!15658 Reduce recursion overhead of split model
From: @gaoxiong1 Reviewed-by: @anyrenwei,@dylangeng Signed-off-by: @dylangeng
This commit is contained in:
commit
d6f58cb765
|
@ -30,7 +30,7 @@ class GraphSplitByPattern:
|
||||||
self.stitch_ops = set()
|
self.stitch_ops = set()
|
||||||
self.stitch_atomic_ops = set()
|
self.stitch_atomic_ops = set()
|
||||||
|
|
||||||
def __init__(self, init_op, is_output):
|
def __init__(self, init_op, is_output, unique_id):
|
||||||
self.pattern = PrimLib.iter_type(init_op)
|
self.pattern = PrimLib.iter_type(init_op)
|
||||||
self.ops = [init_op]
|
self.ops = [init_op]
|
||||||
self.in_relations = dict() # {area1: relation1, area2: relation2, ...}
|
self.in_relations = dict() # {area1: relation1, area2: relation2, ...}
|
||||||
|
@ -48,6 +48,8 @@ class GraphSplitByPattern:
|
||||||
else:
|
else:
|
||||||
_gather_reduce_exclude(to)
|
_gather_reduce_exclude(to)
|
||||||
_gather_reduce_exclude(init_op)
|
_gather_reduce_exclude(init_op)
|
||||||
|
self.reach_map = dict()
|
||||||
|
self.unique_id = unique_id
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return '<' + '-'.join([op.output.name for op in self.ops]) + '>'
|
return '<' + '-'.join([op.output.name for op in self.ops]) + '>'
|
||||||
|
@ -122,13 +124,20 @@ class GraphSplitByPattern:
|
||||||
if area.output_excluded:
|
if area.output_excluded:
|
||||||
self.output_excluded.update(area.output_excluded)
|
self.output_excluded.update(area.output_excluded)
|
||||||
self.update_stitch_info(area.stitch_info)
|
self.update_stitch_info(area.stitch_info)
|
||||||
|
for to, reach in area.reach_map.items():
|
||||||
|
if reach and not self.reach_map.get(to, False):
|
||||||
|
self.reach_map[to] = True
|
||||||
|
|
||||||
def check_acyclic(self, to):
|
def check_acyclic(self, to):
|
||||||
"""Check circle. It returns false if circle exists"""
|
"""Check circle. It returns false if circle exists"""
|
||||||
def _reached(area, to):
|
def _reached(area, to):
|
||||||
|
if to.unique_id in area.reach_map:
|
||||||
|
return area.reach_map[to.unique_id]
|
||||||
for out, _ in area.out_relations.items():
|
for out, _ in area.out_relations.items():
|
||||||
if out == to or _reached(out, to):
|
if out == to or _reached(out, to):
|
||||||
|
area.reach_map[to.unique_id] = True
|
||||||
return True
|
return True
|
||||||
|
area.reach_map[to.unique_id] = False
|
||||||
return False
|
return False
|
||||||
for out, _ in self.out_relations.items():
|
for out, _ in self.out_relations.items():
|
||||||
if out != to and _reached(out, to):
|
if out != to and _reached(out, to):
|
||||||
|
@ -151,9 +160,11 @@ class GraphSplitByPattern:
|
||||||
self.flags = flags
|
self.flags = flags
|
||||||
area_map = {}
|
area_map = {}
|
||||||
_, outputs = graph.deduce_parameters()
|
_, outputs = graph.deduce_parameters()
|
||||||
|
idx = 0
|
||||||
for op in graph.ops:
|
for op in graph.ops:
|
||||||
is_output = op.output in outputs
|
is_output = op.output in outputs
|
||||||
a = self.Area(op, is_output)
|
a = self.Area(op, is_output, idx)
|
||||||
|
idx += 1
|
||||||
self.set_default_mode(a)
|
self.set_default_mode(a)
|
||||||
self.areas.append(a)
|
self.areas.append(a)
|
||||||
area_map[op] = a
|
area_map[op] = a
|
||||||
|
@ -249,7 +260,7 @@ class GraphSplitByPattern:
|
||||||
break
|
break
|
||||||
if out_reshape_ops:
|
if out_reshape_ops:
|
||||||
for op in out_reshape_ops:
|
for op in out_reshape_ops:
|
||||||
a = self.Area(op, False)
|
a = self.Area(op, False, -1)
|
||||||
self.set_default_mode(a)
|
self.set_default_mode(a)
|
||||||
new_areas.append(a)
|
new_areas.append(a)
|
||||||
area.ops = remain_ops
|
area.ops = remain_ops
|
||||||
|
|
|
@ -448,45 +448,26 @@ class Graph:
|
||||||
class GraphVisitor:
|
class GraphVisitor:
|
||||||
"""Graph visitor"""
|
"""Graph visitor"""
|
||||||
|
|
||||||
def __init__(self, forward=True, once_mode=True):
|
def __init__(self, forward=True):
|
||||||
self.forward = forward
|
self.forward = forward
|
||||||
self.once_mode = once_mode
|
|
||||||
if self.once_mode:
|
|
||||||
self.visited = set()
|
|
||||||
|
|
||||||
def visit_graph(self, graph):
|
def visit_graph(self, graph):
|
||||||
"""Visit graph"""
|
"""Visit graph"""
|
||||||
inputs, outputs = graph.deduce_parameters()
|
|
||||||
if self.forward:
|
if self.forward:
|
||||||
for tensor in inputs:
|
for op in graph.ops:
|
||||||
for op in tensor.to_ops:
|
self.visit(op)
|
||||||
self.visit(op)
|
|
||||||
else:
|
else:
|
||||||
for tensor in outputs:
|
for i in range(len(graph.ops)-1, -1, -1):
|
||||||
if not tensor.to_ops:
|
self.visit(graph.ops[i])
|
||||||
self.visit(tensor.op)
|
|
||||||
|
|
||||||
def visit(self, op):
|
|
||||||
"""Visit op"""
|
|
||||||
next_ops = op.output.to_ops if self.forward else [
|
|
||||||
t.op for t in op.inputs if t.op is not None]
|
|
||||||
if self.once_mode:
|
|
||||||
self.visited.add(op)
|
|
||||||
for n in next_ops:
|
|
||||||
if n not in self.visited:
|
|
||||||
self.visit(n)
|
|
||||||
else:
|
|
||||||
for n in next_ops:
|
|
||||||
self.visit(n)
|
|
||||||
|
|
||||||
|
|
||||||
class AlignShape(GraphVisitor):
|
class AlignShape(GraphVisitor):
|
||||||
"""Align shape"""
|
"""Align shape"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(once_mode=False)
|
super().__init__()
|
||||||
|
|
||||||
def visit(self, op):
|
def visit(self, op):
|
||||||
|
"""Visit op node"""
|
||||||
prim = PrimLib.get_prim(op)
|
prim = PrimLib.get_prim(op)
|
||||||
if prim.iter_type in (PrimLib.ELEMWISE, PrimLib.BROADCAST, PrimLib.REDUCE):
|
if prim.iter_type in (PrimLib.ELEMWISE, PrimLib.BROADCAST, PrimLib.REDUCE):
|
||||||
out_dim = len(op.output.shape)
|
out_dim = len(op.output.shape)
|
||||||
|
@ -496,8 +477,6 @@ class AlignShape(GraphVisitor):
|
||||||
align_dim = len(t.shape)
|
align_dim = len(t.shape)
|
||||||
if align_dim > out_dim:
|
if align_dim > out_dim:
|
||||||
op.output.shape = [1] * (align_dim - out_dim) + op.output.shape
|
op.output.shape = [1] * (align_dim - out_dim) + op.output.shape
|
||||||
super().visit(op)
|
|
||||||
|
|
||||||
|
|
||||||
class AddControlBuddy(GraphVisitor):
|
class AddControlBuddy(GraphVisitor):
|
||||||
"""Add control buddy"""
|
"""Add control buddy"""
|
||||||
|
@ -507,6 +486,7 @@ class AddControlBuddy(GraphVisitor):
|
||||||
self.buddies = {} # {op : [ctrl_op]}
|
self.buddies = {} # {op : [ctrl_op]}
|
||||||
|
|
||||||
def visit(self, op):
|
def visit(self, op):
|
||||||
|
"""Visit op node"""
|
||||||
if op.prim == "MakeTuple":
|
if op.prim == "MakeTuple":
|
||||||
assert len(op.output.to_ops) == 1
|
assert len(op.output.to_ops) == 1
|
||||||
owner = op.output.to_ops[0]
|
owner = op.output.to_ops[0]
|
||||||
|
@ -517,9 +497,9 @@ class AddControlBuddy(GraphVisitor):
|
||||||
if op in self.buddies:
|
if op in self.buddies:
|
||||||
ops = self.buddies.pop(op)
|
ops = self.buddies.pop(op)
|
||||||
self.buddies[owner].extend(ops)
|
self.buddies[owner].extend(ops)
|
||||||
super().visit(op)
|
|
||||||
|
|
||||||
def visit_graph(self, graph):
|
def visit_graph(self, graph):
|
||||||
|
"""Visit graph nodes"""
|
||||||
super().visit_graph(graph)
|
super().visit_graph(graph)
|
||||||
for owner in self.buddies:
|
for owner in self.buddies:
|
||||||
for op in self.buddies[owner]:
|
for op in self.buddies[owner]:
|
||||||
|
|
Loading…
Reference in New Issue