forked from mindspore-Ecosystem/mindspore
!14796 [GRAPH KERNEL]optimize stitch fusion strategy
From: @r1chardf1d0 Reviewed-by: @gaoxiong1,@dylangeng Signed-off-by: @dylangeng
This commit is contained in:
commit
7585362148
2
akg
2
akg
|
@ -1 +1 @@
|
|||
Subproject commit e2a30d6b8ece4a69790ac9e37ae862fe8124ad7c
|
||||
Subproject commit d91d772a3a913f20eaef6c47517b9ca140edaee2
|
|
@ -424,6 +424,41 @@ class GraphSplitGpu(GraphSplitByPattern):
|
|||
fused.append(a)
|
||||
return fused, False
|
||||
|
||||
def _stitch_axis(shape, min_size=1024 * 8):
    """Return the shortest leading run of `shape` whose product reaches `min_size`.

    Args:
        shape: iterable of dimension sizes, walked from the first dimension on.
        min_size: threshold for the running product at which the walk stops.
            Defaults to 1024 * 8, the value that used to be hard-coded here.

    Returns:
        A new list with the leading dimensions up to and including the one at
        which the running product first becomes >= `min_size`, or None when
        the product of the whole shape stays below `min_size` (an empty shape
        therefore always yields None).
    """
    leading_dims = []
    size = 1
    for dim in shape:
        size = size * dim
        leading_dims.append(dim)
        # Stop at the first prefix that is large enough; later dims are ignored.
        if size >= min_size:
            return leading_dims
    return None
|
||||
|
||||
def _same_stitch_axis(a, b):
    """Check that every item in `a` and `b` has the same, non-None stitch axis.

    Args:
        a: list of objects exposing a `.shape` attribute.
        b: list of objects exposing a `.shape` attribute.

    Returns:
        True when `_stitch_axis(item.shape)` is not None and is equal for all
        items of `a` and `b` taken together; False otherwise. False is also
        returned when both lists are empty, because there is then no axis to
        agree on (previously this case raised IndexError).
    """
    tensors = list(a) + list(b)
    if not tensors:
        return False
    reference_axis = _stitch_axis(tensors[0].shape)
    # A None reference means the first shape never reaches the stitch threshold.
    if reference_axis is None:
        return False
    # reference_axis is a list here, so an item whose axis is None also fails the comparison.
    return all(_stitch_axis(item.shape) == reference_axis for item in tensors[1:])
|
||||
|
||||
def _may_stitch(dom, a, r):
    """Tell whether area `a` qualifies for stitch fusion with `dom`.

    Args:
        dom: the dominating area; its `ops` and `check_acyclic` are used.
        a: the candidate area; its `pattern` and `ops` are used.
        r: the relation between `dom` and `a`, compared against PrimLib.BROADCAST.

    Returns:
        True only when all of the following hold, checked in this order:
        `a.pattern <= PrimLib.REDUCE`, `r <= PrimLib.BROADCAST`,
        `dom.check_acyclic(a)`, `_reduce_nums(a.ops) < 2`, the stitch tensors
        and the final outputs of `a` pass `_same_stitch_axis`, and at least one
        stitch tensor has `_tensor_size(...) >= 1024 * 1024 * 12`.
        False in every other case.
    """
    pattern_ok = a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a)
    if not pattern_ok:
        return False
    if _reduce_nums(a.ops) >= 2:
        return False

    dom_outputs = [op.output for op in dom.ops]
    consumed_by_a = [src for op in a.ops for src in op.inputs]
    produced_by_a = [op.output for op in a.ops]
    # Outputs of `a` that no op inside `a` reads back.
    final_outputs = [out for out in produced_by_a if out not in consumed_by_a]
    # Outputs of `dom` that some op inside `a` reads.
    stitch_tensors = [out for out in dom_outputs if out in consumed_by_a]

    if not _same_stitch_axis(stitch_tensors, final_outputs):
        return False
    # One sufficiently large stitch tensor is enough.
    return any(_tensor_size(out) >= 1024 * 1024 * 12 for out in stitch_tensors)
|
||||
|
||||
def _reduce_stitch(dom):
|
||||
if dom.pattern != PrimLib.REDUCE:
|
||||
return None
|
||||
|
@ -434,12 +469,14 @@ class GraphSplitGpu(GraphSplitByPattern):
|
|||
|
||||
fused = []
|
||||
for a, r in dom.out_relations.items():
|
||||
if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a):
|
||||
if _reduce_nums(a.ops) < 2:
|
||||
# softmax
|
||||
if len(a.ops) > 4 and len(a.ops[0].inputs[0].shape) == 4:
|
||||
if _may_stitch(dom, a, r):
|
||||
if a.pattern == PrimLib.REDUCE:
|
||||
if a.ops[0].attrs['reduce_axis'] == dom.ops[0].attrs['reduce_axis']:
|
||||
dom.stitch_info.stitch_ops.add(dom.ops[0].output.name)
|
||||
fused.append(a)
|
||||
elif a.pattern == PrimLib.BROADCAST:
|
||||
dom.stitch_info.stitch_ops.add(dom.ops[0].output.name)
|
||||
fused.append(a)
|
||||
return fused, False
|
||||
|
||||
def _transpose(dom):
|
||||
|
|
Loading…
Reference in New Issue