!22677 [GraphKernel] ascend support stitch fusion

Merge pull request !22677 from r1chardf1d0/master
i-robot 2021-09-02 03:31:14 +00:00 committed by Gitee
commit 315431b9a0
2 changed files with 128 additions and 84 deletions


@@ -19,8 +19,67 @@ from .model import PrimLib, Graph, Tensor, Operator
from .model import DataFormat as DF
def tensor_size(tensor):
    """get tensor size"""
    size = 1
    for i in tensor.shape:
        size *= i
    return size


def reduce_nums(ops):
    """get reduce nums"""
    count = 0
    for op in ops:
        if op.prim.startswith('Reduce'):
            count += 1
    return count
def may_stitch(dom, a, r, stitch_axis_size, stitch_buffer_size):
    """check whether dom and a can be stitch-fused"""
    def _same_stitch_axis(stitch_tensors, final_outs, stitch_axis_size):
        """check whether all stitch tensors and final outputs share the same stitch axis"""
        def _stitch_axis(shape, stitch_axis_size):
            """get stitch axis"""
            stitchaxis = []
            size = 1
            for i in shape:
                size = size * i
                stitchaxis.append(i)
                if size >= stitch_axis_size:
                    return stitchaxis
            return None

        x = []
        x.extend(stitch_tensors)
        x.extend(final_outs)
        stitch_axis_0 = _stitch_axis(x[0].shape, stitch_axis_size)
        for item in x:
            i_stitch_axis = _stitch_axis(item.shape, stitch_axis_size)
            if i_stitch_axis is None or i_stitch_axis != stitch_axis_0:
                return False
        return True

    if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a):
        if reduce_nums(a.ops) >= 2:
            return False
        dom_outs = [op.output for op in dom.ops]
        a_ins = [op_input for op in a.ops for op_input in op.inputs]
        a_outs = [op.output for op in a.ops]
        a_final_outs = [tensor for tensor in a_outs if tensor not in a_ins]
        stitch_tensors = [tensor for tensor in dom_outs if tensor in a_ins]
        if not _same_stitch_axis(stitch_tensors, a_final_outs, stitch_axis_size):
            return False
        return any([tensor_size(tensor) >= stitch_buffer_size for tensor in stitch_tensors])
    return False
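
The axis and buffer thresholds are easiest to see with concrete numbers. Below is a minimal, self-contained sketch (not part of the change itself) that restates the `_stitch_axis` accumulation under a hypothetical name `leading_axis` and checks it against the GPU thresholds that this diff later passes to `may_stitch`:

```python
from functools import reduce

def leading_axis(shape, threshold):
    """Same idea as _stitch_axis above: keep leading dims until their
    product reaches `threshold`; return None if it never does."""
    axis, size = [], 1
    for dim in shape:
        size *= dim
        axis.append(dim)
        if size >= threshold:
            return axis
    return None

# GPU thresholds used later in this diff:
STITCH_AXIS_SIZE = 1024 * 8        # 8192
STITCH_BUFFER_SIZE = 1024 * 1024   # 1048576

assert leading_axis([1024, 1024], STITCH_AXIS_SIZE) == [1024, 1024]
assert leading_axis([1024, 4], STITCH_AXIS_SIZE) is None            # 4096 never reaches 8192
# The buffer criterion additionally requires at least one stitch tensor
# whose total size reaches STITCH_BUFFER_SIZE:
assert reduce(lambda a, b: a * b, [1024, 1024]) >= STITCH_BUFFER_SIZE   # 1048576, passes
assert reduce(lambda a, b: a * b, [256, 16, 16]) < STITCH_BUFFER_SIZE   # 65536, too small
```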
class GraphSplitByPattern:
"""Graph splitter"""
class ReachTable:
"""Reachable table"""
@@ -94,6 +153,7 @@ class GraphSplitByPattern:
self.output_excluded.add(to)
else:
_gather_reduce_exclude(to)
_gather_reduce_exclude(init_op)
self.unique_id = unique_id
self.reach_tab = reach_tab
@@ -138,6 +198,7 @@ class GraphSplitByPattern:
def fuse(self, area):
"""Fuse `area` to `self`"""
def _update_relation(relations, a, r):
relations[a] = max(r, relations[a]) if a in relations else r
@@ -243,6 +304,7 @@ class GraphSplitByPattern:
self.areas = []
self.flags = flags
self.enable_recompute = self.flags.get("enable_recompute_fusion", False)
self.enable_stitch_fusion = self.flags.get("enable_stitch_fusion", False)
self.reach_tab = self.ReachTable(len(graph.ops) + 1 if self.enable_recompute else len(graph.ops))
self.area_map = {}
_, outputs = graph.deduce_parameters()
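
For context, a minimal sketch of how the new flag reaches this constructor. The flag key `enable_stitch_fusion` comes from this diff; the idea that the flags dict is populated from MindSpore's `graph_kernel_flags` context option (e.g. `--enable_stitch_fusion=true`) is an assumption about the surrounding tooling, not something stated in the diff:

```python
# Hypothetical flags dict as GraphSplitByPattern.__init__ would receive it;
# in practice it is produced by the graph-kernel flag parser (assumed), e.g.
# context.set_context(enable_graph_kernel=True,
#                     graph_kernel_flags="--enable_stitch_fusion=true")
flags = {"enable_recompute_fusion": True, "enable_stitch_fusion": True}

enable_stitch = flags.get("enable_stitch_fusion", False)  # mirrors the added line above
assert enable_stitch is True
```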
@@ -256,7 +318,7 @@ class GraphSplitByPattern:
self.set_area_map([op], a)
for a in self.areas:
a.link_input(self.area_map)
for i in range(len(self.areas) - 1, -1, -1):
self.areas[i].link_output()
if self.enable_recompute:
self.recom_area = self.Area(None, False, idx, self.reach_tab)
@@ -295,6 +357,7 @@ class GraphSplitByPattern:
def fuse(self, selector):
"""Fuse areas"""
def _fuse_area():
for dominant in self.areas:
result = selector(dominant)
@@ -327,6 +390,7 @@ class GraphSplitByPattern:
def hfuse(self, selector):
"""Fuse horizontal areas with same input tensor"""
def _do_fuse(areas):
for i in range(len(areas) - 1):
sibling = []
@@ -349,6 +413,7 @@ class GraphSplitByPattern:
self.areas.remove(area)
return True
return False
changed = False
while True:
for dom in self.areas:
@@ -426,6 +491,7 @@ class GraphSplitByPattern:
def split_output_reshapes(self):
"""Force split the output Reshapes into other new area"""
def _remove_output_reshape(reshape_ops, other_ops):
def _run():
for op in reshape_ops:
@@ -434,6 +500,7 @@ class GraphSplitByPattern:
other_ops.append(op)
return True
return False
while _run():
pass
@@ -511,6 +578,7 @@ class GraphSplitByPattern:
def find_cheap_regions(self, dom):
"""extract all the cheap regions in dom area, toposort each region before return"""
def _grow_region(region_ops, op, weight, inputs):
"""include op to region_ops if region grow"""
# region successfully ends at inputs
@@ -559,6 +627,7 @@ class GraphSplitByPattern:
def select_user_area(self, tail_tensor):
"""select the user area has only one edge to dom area"""
def _get_edge_num(dom_area, user_area):
"""get edge num between two areas"""
dom_graph = self.to_subgraph(dom_area)
@@ -583,8 +652,10 @@ class GraphSplitByPattern:
def recompute_fuse(self):
"""find recompute regions and copy them out to new Areas"""
def do_recompute_fuse():
"""split the unfusing pattern by add recompute area"""
def recompute_cheap_region(dom):
for cheap_region in cheap_regions:
user_areas = self.select_user_area(cheap_region[-1].output)
@@ -597,6 +668,7 @@ class GraphSplitByPattern:
if self.recom_res:
return True
return False
recompute_suc = False
orig_areas = []
orig_areas.extend(self.areas)
@@ -628,6 +700,7 @@ class GraphSplitGpu(GraphSplitByPattern):
def pattern_fuse(self, fuse_func=None):
"""fuse Areas by pattern"""
def _reshape(dom):
if dom.pattern != PrimLib.RESHAPE:
return None
@@ -638,8 +711,8 @@ class GraphSplitGpu(GraphSplitByPattern):
min_area = a
for a, _ in dom.in_relations.items():
if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \
len(dom.ops[0].inputs[0].to_ops) == 1 and not a.is_output and \
(min_area is None or a.pattern < min_area.pattern):
min_area, forward_fuse = a, True
return ([min_area], forward_fuse) if min_area else None
@@ -718,12 +791,6 @@ class GraphSplitGpu(GraphSplitByPattern):
fused.append(a)
return fused, True
def _tensor_size(tensor):
size = 1
for i in tensor.shape:
size *= i
return size
def _is_atomic_add_available(dom):
if any(["Reduce" in x.prim for x in dom.ops[1:]]):
return False
@@ -739,23 +806,16 @@
return reduce_size >= 1024
return True
def _reduce_nums(ops):
count = 0
for op in ops:
if op.prim.startswith('Reduce'):
count += 1
return count
def _reduce_output(dom):
if dom.pattern != PrimLib.REDUCE:
return None
if reduce_nums(dom.ops) > 1:
return None
if _is_atomic_add_available(dom):
return None
is_all_reduce = tensor_size(dom.ops[0].output) == 1
# excluded large size all reduce
if is_all_reduce and tensor_size(dom.ops[0].inputs[0]) > 1024 * 12:
return None
fused = []
@@ -765,52 +825,17 @@ class GraphSplitGpu(GraphSplitByPattern):
fused.append(a)
return fused, False
def _stitch_axis(shape):
stitch_axis = []
size = 1
for i in shape:
size = size * i
stitch_axis.append(i)
if size >= 1024 * 8:
return stitch_axis
return None
def _same_stitch_axis(a, b):
x = []
x.extend(a)
x.extend(b)
stitch_axis = _stitch_axis(x[0].shape)
for item in x:
i_stitch_axis = _stitch_axis(item.shape)
if i_stitch_axis is None or i_stitch_axis != stitch_axis:
return False
return True
def _may_stitch(dom, a, r):
if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a):
if _reduce_nums(a.ops) >= 2:
return False
dom_outs = [op.output for op in dom.ops]
a_ins = [op_input for op in a.ops for op_input in op.inputs]
a_outs = [op.output for op in a.ops]
a_final_outs = [tensor for tensor in a_outs if tensor not in a_ins]
stitch_tensors = [tensor for tensor in dom_outs if tensor in a_ins]
if not _same_stitch_axis(stitch_tensors, a_final_outs):
return False
return any([_tensor_size(tensor) >= 1024 * 1024 for tensor in stitch_tensors])
return False
def _reduce_stitch(dom):
if dom.pattern != PrimLib.REDUCE:
return None
if tensor_size(dom.ops[0].output) == 1:
return None
if tensor_size(dom.ops[0].inputs[0]) < 1024 * 12:
return None
fused = []
for a, r in dom.out_relations.items():
if not may_stitch(dom, a, r, 1024 * 8, 1024 * 1024):
continue
if a.pattern == PrimLib.REDUCE:
if a.ops[0].attrs['reduce_axis'] == dom.ops[0].attrs['reduce_axis']:
@@ -881,7 +906,7 @@ class GraphSplitGpu(GraphSplitByPattern):
appected_areas = {"TensorScatterAdd", "UnsortedSegmentSum"}
for a, _ in dom.out_relations.items():
if _shape_consistent(gather_prims, appected_areas, dom, a) and \
_count_target_prim(a.ops + dom.ops, appected_areas) < 2 and dom.check_acyclic(a):
return [a], False
return None
@@ -916,8 +941,8 @@ class GraphSplitGpu(GraphSplitByPattern):
for a in sibling:
op = a.ops[0]
if a.pattern == PrimLib.REDUCE and not a.stitch_info.stitch_ops and \
PrimLib.is_reduce(op) and dom_op.inputs[0].shape == op.inputs[0].shape and \
dom_op.attrs.get("reduce_axis") == op.attrs.get("reduce_axis"):
fused.append(a)
return fused
@@ -927,7 +952,7 @@ class GraphSplitGpu(GraphSplitByPattern):
fused = []
for a in sibling:
if a.ops[0].prim == dom.ops[0].prim and dom.ops[0].output.shape == a.ops[0].output.shape and \
dom.ops[0].inputs[0].shape == a.ops[0].inputs[0].shape:
fused.append(a)
return fused
@@ -945,7 +970,7 @@ class GraphSplitGpu(GraphSplitByPattern):
changed = self.fuse(_broadcast_opaque) or changed
changed = self.fuse(_gather_output) or changed
changed = self.fuse(_reduce_output) or changed
if self.enable_stitch_fusion:
changed = self.fuse(_reduce_stitch) or changed
self.fuse(_transpose)
self.hfuse(_h_broadcast)
@@ -957,12 +982,11 @@ class GraphSplitGpu(GraphSplitByPattern):
fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \
fuse_func(_broadcast_width):
return
if fuse_func(_reduce_output) or (self.enable_stitch_fusion and fuse_func(_reduce_stitch)):
return
fuse_func(_transpose)
return
if fuse_func is None:
_fuse_loop()
else:
@@ -976,6 +1000,7 @@ class GraphSplitAscend(GraphSplitByPattern):
def get_default_mode(self, op):
"""Get default mode for Ascend"""
def _dtype_same(tensors):
dtype = tensors[0].dtype
for tensor_ in tensors:
@@ -992,15 +1017,10 @@ class GraphSplitAscend(GraphSplitByPattern):
def pattern_fuse(self, fuse_func=None):
"""fuse Areas by pattern"""
def _tensor_size(tensor):
size = 1
for i in tensor.shape:
size *= i
return size
def _likely_multicore(dom):
op = dom.dom_op()
iter_size = tensor_size(op.output if not PrimLib.is_reduce(op) else op.inputs[0])
return iter_size > 1024
def _reshape(dom):
@@ -1013,8 +1033,8 @@ class GraphSplitAscend(GraphSplitByPattern):
min_area = a
for a, _ in dom.in_relations.items():
if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \
len(dom.ops[0].inputs[0].to_ops) == 1 and not a.is_output and \
(min_area is None or a.pattern < min_area.pattern):
min_area, forward_fuse = a, True
return ([min_area], forward_fuse) if min_area else None
@@ -1111,6 +1131,27 @@ class GraphSplitAscend(GraphSplitByPattern):
fused.append(a)
return fused, False
def _reduce_stitch(dom):
    if dom.pattern != PrimLib.REDUCE:
        return None
    if tensor_size(dom.ops[0].output) == 1:
        return None
    if tensor_size(dom.ops[0].inputs[0]) < 32 * 16 * 16:
        return None
    fused = []
    for a, r in dom.out_relations.items():
        if not may_stitch(dom, a, r, 32, 32 * 16 * 16):
            continue
        if a.pattern == PrimLib.REDUCE:
            if a.ops[0].attrs['reduce_axis'] == dom.ops[0].attrs['reduce_axis']:
                dom.stitch_info.stitch_ops.add(dom.ops[0].output.name)
                fused.append(a)
        elif a.pattern <= PrimLib.BROADCAST:
            dom.stitch_info.stitch_ops.add(dom.ops[0].output.name)
            fused.append(a)
    return fused, False
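
As a side-by-side reference (not part of the diff itself), the stitch thresholds now differ per backend; the values in this small sketch are copied from the GPU and the newly added Ascend `_reduce_stitch` branches above, the dict grouping is only illustrative:

```python
# (stitch_axis_size, stitch_buffer_size, minimum reduce-input size) per backend,
# taken from the two _reduce_stitch implementations in this change.
STITCH_THRESHOLDS = {
    "GPU": (1024 * 8, 1024 * 1024, 1024 * 12),   # 8192, 1048576, 12288
    "Ascend": (32, 32 * 16 * 16, 32 * 16 * 16),  # 32, 8192, 8192
}
```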
def _transdata_pattern_support(dom, a):
transdata_op = dom.dom_op()
@@ -1127,6 +1168,7 @@ class GraphSplitAscend(GraphSplitByPattern):
if dim % cube_size != 0:
res = True
return res
has_pad = _has_pad()
if has_pad:
return False
@@ -1145,7 +1187,7 @@ class GraphSplitAscend(GraphSplitByPattern):
if src_format == DF.FRAC_NZ and dst_format in (DF.DEFAULT, DF.NCHW):
return True
# For the Default/NCHW to FRAC_NZ, currently only the Cast+Transdata is supported
if src_format in (DF.DEFAULT, DF.NCHW) and dst_format == DF.FRAC_NZ \
and len(a.ops) == 1 and a.dom_op().prim == "Cast" and not a.is_output:
return True
return False
@@ -1171,6 +1213,8 @@ class GraphSplitAscend(GraphSplitByPattern):
changed = self.fuse(_broadcast_width) or changed
changed = self.fuse(_matmul_depth) or changed
changed = self.fuse(_reduce_output) or changed
if self.enable_stitch_fusion:
    changed = self.fuse(_reduce_stitch) or changed
self.fuse(_transdata)
def _fuse_once(fuse_func):


@@ -101,6 +101,7 @@ class OpInfer:
class _Elemwise(OpInfer):
"""Common infer for elementwise operators"""
@staticmethod
def broadcast_shape(shapes):
"""deduce broadcast shape using same rules as numpy"""
@@ -120,25 +121,24 @@
@staticmethod
def defaultformat_to_nz(default_shape):
    """default format shape to fractal_Nz format shape"""
    more_two_d_shape, two_d_shape = default_shape[:-2], default_shape[-2:]
    # (32) or (1, 32) -> (2, 1, 1, 16)
    if len(two_d_shape) == 1 or (len(two_d_shape) == 2 and two_d_shape[0] == 1):
        shape = [two_d_shape[-1] // 16, 1, 1, 16]
        if two_d_shape[-1] % 16 != 0:
            raise GKException("should be multiples of 16")
        return shape
    # (32, 1) -> (1, 2, 16, 1)
    if len(two_d_shape) == 2 and two_d_shape[1] == 1:
        shape = [1, two_d_shape[0] // 16, 16, 1]
        if two_d_shape[0] % 16 != 0:
            raise GKException("should be multiples of 16")
        return shape
    # (32, 48) -> (3, 2, 16, 16)
    shape = [two_d_shape[1] // 16, two_d_shape[0] // 16, 16, 16]
    if two_d_shape[0] % 16 != 0 or two_d_shape[1] % 16 != 0:
        raise GKException("should be multiples of 16")
    return more_two_d_shape + shape
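
Before this change the helper only accepted 1-D or 2-D shapes and raised GKException("shape is too long!") otherwise; now any leading batch dimensions are split off into `more_two_d_shape` and prepended to the converted trailing 2-D block. A short illustration follows; the import path is an assumption, not stated in this diff, and the expected values follow directly from the code above:

```python
# Illustration only; module path assumed, results derived from defaultformat_to_nz above.
from mindspore._extends.graph_kernel.model.op_infer import _Elemwise

print(_Elemwise.defaultformat_to_nz([32]))         # [2, 1, 1, 16]
print(_Elemwise.defaultformat_to_nz([32, 48]))     # [3, 2, 16, 16]
print(_Elemwise.defaultformat_to_nz([8, 32, 48]))  # [8, 3, 2, 16, 16] (leading dim kept in front)
```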
def _infer_shape(self):
"""returns the output shape with broadcast"""