From 71002ed19d7f7fa217d26ed3db03070855ddec41 Mon Sep 17 00:00:00 2001 From: Gaoxiong Date: Sun, 25 Apr 2021 21:06:26 +0800 Subject: [PATCH] reduce recursion overhead of split model --- .../graph_kernel/model/graph_split.py | 17 +++++++-- .../_extends/graph_kernel/model/model.py | 38 +++++-------------- 2 files changed, 23 insertions(+), 32 deletions(-) diff --git a/mindspore/_extends/graph_kernel/model/graph_split.py b/mindspore/_extends/graph_kernel/model/graph_split.py index b83691ad38d..ec102ecbffa 100644 --- a/mindspore/_extends/graph_kernel/model/graph_split.py +++ b/mindspore/_extends/graph_kernel/model/graph_split.py @@ -30,7 +30,7 @@ class GraphSplitByPattern: self.stitch_ops = set() self.stitch_atomic_ops = set() - def __init__(self, init_op, is_output): + def __init__(self, init_op, is_output, unique_id): self.pattern = PrimLib.iter_type(init_op) self.ops = [init_op] self.in_relations = dict() # {area1: relation1, area2: relation2, ...} @@ -48,6 +48,8 @@ class GraphSplitByPattern: else: _gather_reduce_exclude(to) _gather_reduce_exclude(init_op) + self.reach_map = dict() + self.unique_id = unique_id def __str__(self): return '<' + '-'.join([op.output.name for op in self.ops]) + '>' @@ -122,13 +124,20 @@ class GraphSplitByPattern: if area.output_excluded: self.output_excluded.update(area.output_excluded) self.update_stitch_info(area.stitch_info) + for to, reach in area.reach_map.items(): + if reach and not self.reach_map.get(to, False): + self.reach_map[to] = True def check_acyclic(self, to): """Check circle. It returns false if circle exists""" def _reached(area, to): + if to.unique_id in area.reach_map: + return area.reach_map[to.unique_id] for out, _ in area.out_relations.items(): if out == to or _reached(out, to): + area.reach_map[to.unique_id] = True return True + area.reach_map[to.unique_id] = False return False for out, _ in self.out_relations.items(): if out != to and _reached(out, to): @@ -151,9 +160,11 @@ class GraphSplitByPattern: self.flags = flags area_map = {} _, outputs = graph.deduce_parameters() + idx = 0 for op in graph.ops: is_output = op.output in outputs - a = self.Area(op, is_output) + a = self.Area(op, is_output, idx) + idx += 1 self.set_default_mode(a) self.areas.append(a) area_map[op] = a @@ -249,7 +260,7 @@ class GraphSplitByPattern: break if out_reshape_ops: for op in out_reshape_ops: - a = self.Area(op, False) + a = self.Area(op, False, -1) self.set_default_mode(a) new_areas.append(a) area.ops = remain_ops diff --git a/mindspore/_extends/graph_kernel/model/model.py b/mindspore/_extends/graph_kernel/model/model.py index 359001b08a3..5fc0d17b1ea 100644 --- a/mindspore/_extends/graph_kernel/model/model.py +++ b/mindspore/_extends/graph_kernel/model/model.py @@ -448,45 +448,26 @@ class Graph: class GraphVisitor: """Graph visitor""" - def __init__(self, forward=True, once_mode=True): + def __init__(self, forward=True): self.forward = forward - self.once_mode = once_mode - if self.once_mode: - self.visited = set() def visit_graph(self, graph): """Visit graph""" - inputs, outputs = graph.deduce_parameters() if self.forward: - for tensor in inputs: - for op in tensor.to_ops: - self.visit(op) + for op in graph.ops: + self.visit(op) else: - for tensor in outputs: - if not tensor.to_ops: - self.visit(tensor.op) - - def visit(self, op): - """Visit op""" - next_ops = op.output.to_ops if self.forward else [ - t.op for t in op.inputs if t.op is not None] - if self.once_mode: - self.visited.add(op) - for n in next_ops: - if n not in self.visited: - self.visit(n) - else: - for n in next_ops: - self.visit(n) - + for i in range(len(graph.ops)-1, -1, -1): + self.visit(graph.ops[i]) class AlignShape(GraphVisitor): """Align shape""" def __init__(self): - super().__init__(once_mode=False) + super().__init__() def visit(self, op): + """Visit op node""" prim = PrimLib.get_prim(op) if prim.iter_type in (PrimLib.ELEMWISE, PrimLib.BROADCAST, PrimLib.REDUCE): out_dim = len(op.output.shape) @@ -496,8 +477,6 @@ class AlignShape(GraphVisitor): align_dim = len(t.shape) if align_dim > out_dim: op.output.shape = [1] * (align_dim - out_dim) + op.output.shape - super().visit(op) - class AddControlBuddy(GraphVisitor): """Add control buddy""" @@ -507,6 +486,7 @@ class AddControlBuddy(GraphVisitor): self.buddies = {} # {op : [ctrl_op]} def visit(self, op): + """Visit op node""" if op.prim == "MakeTuple": assert len(op.output.to_ops) == 1 owner = op.output.to_ops[0] @@ -517,9 +497,9 @@ class AddControlBuddy(GraphVisitor): if op in self.buddies: ops = self.buddies.pop(op) self.buddies[owner].extend(ops) - super().visit(op) def visit_graph(self, graph): + """Visit graph nodes""" super().visit_graph(graph) for owner in self.buddies: for op in self.buddies[owner]: