From 4bc67f38ded7db00cd36cdab37598ea7130d25e0 Mon Sep 17 00:00:00 2001 From: Gaoxiong Date: Fri, 30 Apr 2021 15:57:29 +0800 Subject: [PATCH] eliminate recursion call --- .../graph_kernel/model/graph_split.py | 58 +++++++++++++------ 1 file changed, 39 insertions(+), 19 deletions(-) diff --git a/mindspore/_extends/graph_kernel/model/graph_split.py b/mindspore/_extends/graph_kernel/model/graph_split.py index ec102ecbffa..c2b67eafceb 100644 --- a/mindspore/_extends/graph_kernel/model/graph_split.py +++ b/mindspore/_extends/graph_kernel/model/graph_split.py @@ -19,6 +19,34 @@ from .model import PrimLib, Graph, Tensor class GraphSplitByPattern: """Graph splitter""" + class ReachTable: + """Reachable table""" + def __init__(self, size): + self.map = [] + self.alive = set(range(size)) + for i in range(0, size): + self.map.append([False for j in range(0, size)]) + self.map[i][i] = True + + def reachable(self, x, y): + """reachable from x to y""" + return self.map[x][y] + + def sync(self, x, y): + """sync from y to x""" + for i in self.alive: + if self.map[y][i] and not self.map[x][i]: + self.map[x][i] = True + + def fuse(self, x, y): + """fuse y to x""" + for i in self.alive: + if self.map[y][i] and not self.map[x][i]: + self.map[x][i] = True + if self.map[i][y] and not self.map[i][x]: + self.map[i][x] = True + self.alive.remove(y) + class Area: """Area""" MODE_BASIC = 1 @@ -30,7 +58,7 @@ class GraphSplitByPattern: self.stitch_ops = set() self.stitch_atomic_ops = set() - def __init__(self, init_op, is_output, unique_id): + def __init__(self, init_op, is_output, unique_id, reach_tab): self.pattern = PrimLib.iter_type(init_op) self.ops = [init_op] self.in_relations = dict() # {area1: relation1, area2: relation2, ...} @@ -48,8 +76,8 @@ class GraphSplitByPattern: else: _gather_reduce_exclude(to) _gather_reduce_exclude(init_op) - self.reach_map = dict() self.unique_id = unique_id + self.reach_tab = reach_tab def __str__(self): return '<' + '-'.join([op.output.name for op in self.ops]) + '>' @@ -78,6 +106,8 @@ class GraphSplitByPattern: """Link outputs""" for input_area, r in self.in_relations.items(): input_area.out_relations[self] = r + for out, _ in self.out_relations.items(): + self.reach_tab.sync(self.unique_id, out.unique_id) def update_stitch_info(self, stitch_info): if stitch_info.stitch_ops: @@ -124,23 +154,12 @@ class GraphSplitByPattern: if area.output_excluded: self.output_excluded.update(area.output_excluded) self.update_stitch_info(area.stitch_info) - for to, reach in area.reach_map.items(): - if reach and not self.reach_map.get(to, False): - self.reach_map[to] = True + self.reach_tab.fuse(self.unique_id, area.unique_id) def check_acyclic(self, to): """Check circle. It returns false if circle exists""" - def _reached(area, to): - if to.unique_id in area.reach_map: - return area.reach_map[to.unique_id] - for out, _ in area.out_relations.items(): - if out == to or _reached(out, to): - area.reach_map[to.unique_id] = True - return True - area.reach_map[to.unique_id] = False - return False for out, _ in self.out_relations.items(): - if out != to and _reached(out, to): + if out != to and self.reach_tab.reachable(out.unique_id, to.unique_id): return False return True @@ -158,20 +177,21 @@ class GraphSplitByPattern: self.graph = graph self.areas = [] self.flags = flags + self.reach_tab = self.ReachTable(len(graph.ops)) area_map = {} _, outputs = graph.deduce_parameters() idx = 0 for op in graph.ops: is_output = op.output in outputs - a = self.Area(op, is_output, idx) + a = self.Area(op, is_output, idx, self.reach_tab) idx += 1 self.set_default_mode(a) self.areas.append(a) area_map[op] = a for a in self.areas: a.link_input(area_map) - for a in self.areas: - a.link_output() + for i in range(len(self.areas)-1, -1, -1): + self.areas[i].link_output() def set_default_mode(self, area): area.mode = self.get_default_mode(area.ops[0]) @@ -260,7 +280,7 @@ class GraphSplitByPattern: break if out_reshape_ops: for op in out_reshape_ops: - a = self.Area(op, False, -1) + a = self.Area(op, False, 0, self.reach_tab) self.set_default_mode(a) new_areas.append(a) area.ops = remain_ops