!14796 [GRAPH KERNEL] Optimize stitch fusion strategy

From: @r1chardf1d0
Reviewed-by: @gaoxiong1,@dylangeng
Signed-off-by: @dylangeng
This commit is contained in:
mindspore-ci-bot 2021-04-25 14:41:49 +08:00 committed by Gitee
commit 7585362148
2 changed files with 42 additions and 5 deletions

2
akg

@ -1 +1 @@
Subproject commit e2a30d6b8ece4a69790ac9e37ae862fe8124ad7c
Subproject commit d91d772a3a913f20eaef6c47517b9ca140edaee2

View File

@ -424,6 +424,41 @@ class GraphSplitGpu(GraphSplitByPattern):
fused.append(a)
return fused, False
def _stitch_axis(shape):
stitch_axis = []
size = 1
for i in shape:
size = size * i
stitch_axis.append(i)
if size >= 1024 * 8:
return stitch_axis
return None
def _same_stitch_axis(a, b):
x = []
x.extend(a)
x.extend(b)
stitch_axis = _stitch_axis(x[0].shape)
for item in x:
i_stitch_axis = _stitch_axis(item.shape)
if i_stitch_axis is None or i_stitch_axis != stitch_axis:
return False
return True
def _may_stitch(dom, a, r):
    """Decide whether `a` qualifies for stitching with `dom`.

    All of the following must hold, checked in this order:
      * `a.pattern` is at most `PrimLib.REDUCE`, `r` is at most
        `PrimLib.BROADCAST`, and `dom.check_acyclic(a)` is truthy;
      * `_reduce_nums(a.ops)` is below 2;
      * the outputs of `dom` that `a` consumes, together with the outputs of
        `a` that are not consumed inside `a`, pass `_same_stitch_axis`;
      * at least one of those consumed `dom` outputs has a `_tensor_size` of
        at least 1024 * 1024 * 12.

    Args:
        dom: object exposing `ops` and `check_acyclic`.
        a: object exposing `pattern` and `ops`.
        r: value compared against `PrimLib.BROADCAST`.
            NOTE(review): presumably the relation between `dom` and `a`;
            confirm against the caller.

    Returns:
        True when every condition above holds, otherwise False.
    """
    eligible = (a.pattern <= PrimLib.REDUCE
                and r <= PrimLib.BROADCAST
                and dom.check_acyclic(a)
                and _reduce_nums(a.ops) < 2)
    if not eligible:
        return False

    produced_by_dom = [op.output for op in dom.ops]
    consumed_by_a = [operand for op in a.ops for operand in op.inputs]
    produced_by_a = [op.output for op in a.ops]
    # Outputs of `a` that no op inside `a` reads back.
    final_outputs = [t for t in produced_by_a if t not in consumed_by_a]
    # Tensors handed over from `dom` to `a`.
    shared = [t for t in produced_by_dom if t in consumed_by_a]

    if not _same_stitch_axis(shared, final_outputs):
        return False
    return any(_tensor_size(t) >= 1024 * 1024 * 12 for t in shared)
def _reduce_stitch(dom):
if dom.pattern != PrimLib.REDUCE:
return None
@ -434,12 +469,14 @@ class GraphSplitGpu(GraphSplitByPattern):
fused = []
for a, r in dom.out_relations.items():
if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a):
if _reduce_nums(a.ops) < 2:
# softmax
if len(a.ops) > 4 and len(a.ops[0].inputs[0].shape) == 4:
if _may_stitch(dom, a, r):
if a.pattern == PrimLib.REDUCE:
if a.ops[0].attrs['reduce_axis'] == dom.ops[0].attrs['reduce_axis']:
dom.stitch_info.stitch_ops.add(dom.ops[0].output.name)
fused.append(a)
elif a.pattern == PrimLib.BROADCAST:
dom.stitch_info.stitch_ops.add(dom.ops[0].output.name)
fused.append(a)
return fused, False
def _transpose(dom):