forked from mindspore-Ecosystem/mindspore
eliminate recursion call
This commit is contained in:
parent
7b91d67907
commit
4bc67f38de
|
@ -19,6 +19,34 @@ from .model import PrimLib, Graph, Tensor
|
||||||
|
|
||||||
class GraphSplitByPattern:
|
class GraphSplitByPattern:
|
||||||
"""Graph splitter"""
|
"""Graph splitter"""
|
||||||
|
class ReachTable:
|
||||||
|
"""Reachable table"""
|
||||||
|
def __init__(self, size):
|
||||||
|
self.map = []
|
||||||
|
self.alive = set(range(size))
|
||||||
|
for i in range(0, size):
|
||||||
|
self.map.append([False for j in range(0, size)])
|
||||||
|
self.map[i][i] = True
|
||||||
|
|
||||||
|
def reachable(self, x, y):
|
||||||
|
"""reachable from x to y"""
|
||||||
|
return self.map[x][y]
|
||||||
|
|
||||||
|
def sync(self, x, y):
|
||||||
|
"""sync from y to x"""
|
||||||
|
for i in self.alive:
|
||||||
|
if self.map[y][i] and not self.map[x][i]:
|
||||||
|
self.map[x][i] = True
|
||||||
|
|
||||||
|
def fuse(self, x, y):
|
||||||
|
"""fuse y to x"""
|
||||||
|
for i in self.alive:
|
||||||
|
if self.map[y][i] and not self.map[x][i]:
|
||||||
|
self.map[x][i] = True
|
||||||
|
if self.map[i][y] and not self.map[i][x]:
|
||||||
|
self.map[i][x] = True
|
||||||
|
self.alive.remove(y)
|
||||||
|
|
||||||
class Area:
|
class Area:
|
||||||
"""Area"""
|
"""Area"""
|
||||||
MODE_BASIC = 1
|
MODE_BASIC = 1
|
||||||
|
@ -30,7 +58,7 @@ class GraphSplitByPattern:
|
||||||
self.stitch_ops = set()
|
self.stitch_ops = set()
|
||||||
self.stitch_atomic_ops = set()
|
self.stitch_atomic_ops = set()
|
||||||
|
|
||||||
def __init__(self, init_op, is_output, unique_id):
|
def __init__(self, init_op, is_output, unique_id, reach_tab):
|
||||||
self.pattern = PrimLib.iter_type(init_op)
|
self.pattern = PrimLib.iter_type(init_op)
|
||||||
self.ops = [init_op]
|
self.ops = [init_op]
|
||||||
self.in_relations = dict() # {area1: relation1, area2: relation2, ...}
|
self.in_relations = dict() # {area1: relation1, area2: relation2, ...}
|
||||||
|
@ -48,8 +76,8 @@ class GraphSplitByPattern:
|
||||||
else:
|
else:
|
||||||
_gather_reduce_exclude(to)
|
_gather_reduce_exclude(to)
|
||||||
_gather_reduce_exclude(init_op)
|
_gather_reduce_exclude(init_op)
|
||||||
self.reach_map = dict()
|
|
||||||
self.unique_id = unique_id
|
self.unique_id = unique_id
|
||||||
|
self.reach_tab = reach_tab
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return '<' + '-'.join([op.output.name for op in self.ops]) + '>'
|
return '<' + '-'.join([op.output.name for op in self.ops]) + '>'
|
||||||
|
@ -78,6 +106,8 @@ class GraphSplitByPattern:
|
||||||
"""Link outputs"""
|
"""Link outputs"""
|
||||||
for input_area, r in self.in_relations.items():
|
for input_area, r in self.in_relations.items():
|
||||||
input_area.out_relations[self] = r
|
input_area.out_relations[self] = r
|
||||||
|
for out, _ in self.out_relations.items():
|
||||||
|
self.reach_tab.sync(self.unique_id, out.unique_id)
|
||||||
|
|
||||||
def update_stitch_info(self, stitch_info):
|
def update_stitch_info(self, stitch_info):
|
||||||
if stitch_info.stitch_ops:
|
if stitch_info.stitch_ops:
|
||||||
|
@ -124,23 +154,12 @@ class GraphSplitByPattern:
|
||||||
if area.output_excluded:
|
if area.output_excluded:
|
||||||
self.output_excluded.update(area.output_excluded)
|
self.output_excluded.update(area.output_excluded)
|
||||||
self.update_stitch_info(area.stitch_info)
|
self.update_stitch_info(area.stitch_info)
|
||||||
for to, reach in area.reach_map.items():
|
self.reach_tab.fuse(self.unique_id, area.unique_id)
|
||||||
if reach and not self.reach_map.get(to, False):
|
|
||||||
self.reach_map[to] = True
|
|
||||||
|
|
||||||
def check_acyclic(self, to):
|
def check_acyclic(self, to):
|
||||||
"""Check circle. It returns false if circle exists"""
|
"""Check circle. It returns false if circle exists"""
|
||||||
def _reached(area, to):
|
|
||||||
if to.unique_id in area.reach_map:
|
|
||||||
return area.reach_map[to.unique_id]
|
|
||||||
for out, _ in area.out_relations.items():
|
|
||||||
if out == to or _reached(out, to):
|
|
||||||
area.reach_map[to.unique_id] = True
|
|
||||||
return True
|
|
||||||
area.reach_map[to.unique_id] = False
|
|
||||||
return False
|
|
||||||
for out, _ in self.out_relations.items():
|
for out, _ in self.out_relations.items():
|
||||||
if out != to and _reached(out, to):
|
if out != to and self.reach_tab.reachable(out.unique_id, to.unique_id):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -158,20 +177,21 @@ class GraphSplitByPattern:
|
||||||
self.graph = graph
|
self.graph = graph
|
||||||
self.areas = []
|
self.areas = []
|
||||||
self.flags = flags
|
self.flags = flags
|
||||||
|
self.reach_tab = self.ReachTable(len(graph.ops))
|
||||||
area_map = {}
|
area_map = {}
|
||||||
_, outputs = graph.deduce_parameters()
|
_, outputs = graph.deduce_parameters()
|
||||||
idx = 0
|
idx = 0
|
||||||
for op in graph.ops:
|
for op in graph.ops:
|
||||||
is_output = op.output in outputs
|
is_output = op.output in outputs
|
||||||
a = self.Area(op, is_output, idx)
|
a = self.Area(op, is_output, idx, self.reach_tab)
|
||||||
idx += 1
|
idx += 1
|
||||||
self.set_default_mode(a)
|
self.set_default_mode(a)
|
||||||
self.areas.append(a)
|
self.areas.append(a)
|
||||||
area_map[op] = a
|
area_map[op] = a
|
||||||
for a in self.areas:
|
for a in self.areas:
|
||||||
a.link_input(area_map)
|
a.link_input(area_map)
|
||||||
for a in self.areas:
|
for i in range(len(self.areas)-1, -1, -1):
|
||||||
a.link_output()
|
self.areas[i].link_output()
|
||||||
|
|
||||||
def set_default_mode(self, area):
|
def set_default_mode(self, area):
|
||||||
area.mode = self.get_default_mode(area.ops[0])
|
area.mode = self.get_default_mode(area.ops[0])
|
||||||
|
@ -260,7 +280,7 @@ class GraphSplitByPattern:
|
||||||
break
|
break
|
||||||
if out_reshape_ops:
|
if out_reshape_ops:
|
||||||
for op in out_reshape_ops:
|
for op in out_reshape_ops:
|
||||||
a = self.Area(op, False, -1)
|
a = self.Area(op, False, 0, self.reach_tab)
|
||||||
self.set_default_mode(a)
|
self.set_default_mode(a)
|
||||||
new_areas.append(a)
|
new_areas.append(a)
|
||||||
area.ops = remain_ops
|
area.ops = remain_ops
|
||||||
|
|
Loading…
Reference in New Issue