From 5c5d125b1d23ff5389a5b9ed5035d60a4213363d Mon Sep 17 00:00:00 2001
From: r1chardf1d0
Date: Thu, 8 Apr 2021 18:29:52 +0800
Subject: [PATCH] optimize stitch fusion strategy

---
Notes (commentary only; not part of the commit):

  * akg: the submodule pointer moves from e2a30d6b8ec to d91d772a3a9.
  * graph_split.py adds three helpers next to _reduce_stitch:
      - _stitch_axis(shape): collects the leading dimensions of `shape`
        until their running product reaches 1024 * 8 and returns them;
        returns None if the product never gets that large.
      - _same_stitch_axis(a, b): True only when every tensor in a + b has
        a non-None stitch axis equal to that of the first tensor.
      - _may_stitch(dom, a, r): keeps the previous pattern / relation /
        acyclic / `_reduce_nums(a.ops) < 2` guard, then also requires
        matching stitch axes between the tensors `dom` feeds into `a` and
        the outputs of `a` that `a` does not consume itself, and at least
        one fed tensor with `_tensor_size >= 1024 * 1024 * 12`.
  * _reduce_stitch no longer uses the softmax-specific test
    (`len(a.ops) > 4` and a 4-D first input).  After _may_stitch accepts
    an area, REDUCE areas are fused only if their first op has the same
    `reduce_axis` attr as dom's first op; BROADCAST areas are fused
    directly.
  * NOTE(review): an area accepted by _may_stitch whose pattern is neither
    REDUCE nor BROADCAST is no longer fused - confirm this is intended.
  * NOTE(review): _same_stitch_axis indexes x[0]; this assumes a + b is
    never empty - confirm against callers.
  * NOTE(review): the comprehension variable `input` in _may_stitch
    shadows the builtin of the same name.
  * NOTE(review): the line structure and indentation of this patch were
    reconstructed from a whitespace-collapsed copy (nested defs assumed to
    sit at 8 spaces).  If a context line does not match the target tree,
    try `git apply --ignore-whitespace`.

 akg                                           |  2 +-
 .../graph_kernel/model/graph_split.py         | 45 +++++++++++++++++--
 2 files changed, 42 insertions(+), 5 deletions(-)

diff --git a/akg b/akg
index e2a30d6b8ec..d91d772a3a9 160000
--- a/akg
+++ b/akg
@@ -1 +1 @@
-Subproject commit e2a30d6b8ece4a69790ac9e37ae862fe8124ad7c
+Subproject commit d91d772a3a913f20eaef6c47517b9ca140edaee2
diff --git a/mindspore/_extends/graph_kernel/model/graph_split.py b/mindspore/_extends/graph_kernel/model/graph_split.py
index c0286944b6b..b83691ad38d 100644
--- a/mindspore/_extends/graph_kernel/model/graph_split.py
+++ b/mindspore/_extends/graph_kernel/model/graph_split.py
@@ -424,6 +424,41 @@ class GraphSplitGpu(GraphSplitByPattern):
                     fused.append(a)
             return fused, False
 
+        def _stitch_axis(shape):
+            stitch_axis = []
+            size = 1
+            for i in shape:
+                size = size * i
+                stitch_axis.append(i)
+                if size >= 1024 * 8:
+                    return stitch_axis
+            return None
+
+        def _same_stitch_axis(a, b):
+            x = []
+            x.extend(a)
+            x.extend(b)
+            stitch_axis = _stitch_axis(x[0].shape)
+            for item in x:
+                i_stitch_axis = _stitch_axis(item.shape)
+                if i_stitch_axis is None or i_stitch_axis != stitch_axis:
+                    return False
+            return True
+
+        def _may_stitch(dom, a, r):
+            if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a):
+                if _reduce_nums(a.ops) < 2:
+                    dom_outs = [op.output for op in dom.ops]
+                    a_ins = [input for op in a.ops for input in op.inputs]
+                    a_outs = [op.output for op in a.ops]
+                    a_final_outs = [tensor for tensor in a_outs if tensor not in a_ins]
+                    stitch_tensors = [tensor for tensor in dom_outs if tensor in a_ins]
+                    if _same_stitch_axis(stitch_tensors, a_final_outs):
+                        for tensor in stitch_tensors:
+                            if _tensor_size(tensor) >= 1024 * 1024 * 12:
+                                return True
+            return False
+
         def _reduce_stitch(dom):
             if dom.pattern != PrimLib.REDUCE:
                 return None
@@ -434,12 +469,14 @@ class GraphSplitGpu(GraphSplitByPattern):
 
             fused = []
             for a, r in dom.out_relations.items():
-                if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a):
-                    if _reduce_nums(a.ops) < 2:
-                        # softmax
-                        if len(a.ops) > 4 and len(a.ops[0].inputs[0].shape) == 4:
+                if _may_stitch(dom, a, r):
+                    if a.pattern == PrimLib.REDUCE:
+                        if a.ops[0].attrs['reduce_axis'] == dom.ops[0].attrs['reduce_axis']:
                             dom.stitch_info.stitch_ops.add(dom.ops[0].output.name)
                             fused.append(a)
+                    elif a.pattern == PrimLib.BROADCAST:
+                        dom.stitch_info.stitch_ops.add(dom.ops[0].output.name)
+                        fused.append(a)
             return fused, False
 
         def _transpose(dom):