!22677 [GraphKernel] ascend support stitch fusion
Merge pull request !22677 from r1chardf1d0/master
This commit is contained in: commit 315431b9a0
@@ -19,8 +19,67 @@ from .model import PrimLib, Graph, Tensor, Operator
 from .model import DataFormat as DF


+def tensor_size(tensor):
+    """get tensor size"""
+    size = 1
+    for i in tensor.shape:
+        size *= i
+    return size
+
+
+def reduce_nums(ops):
+    """get reduce nums"""
+    count = 0
+    for op in ops:
+        if op.prim.startswith('Reduce'):
+            count += 1
+    return count
+
+
+def may_stitch(dom, a, r, stitch_axis_size, stitch_buffer_size):
+    """check if can stitch"""
+
+    def _same_stitch_axis(stitch_tensors, final_outs, stitch_axis_size):
+        """does a and b have same stitch axis"""
+
+        def _stitch_axis(shape, stitch_axis_size):
+            """get stitch axis"""
+            stitchaxis = []
+            size = 1
+            for i in shape:
+                size = size * i
+                stitchaxis.append(i)
+                if size >= stitch_axis_size:
+                    return stitchaxis
+            return None
+
+        x = []
+        x.extend(stitch_tensors)
+        x.extend(final_outs)
+        stitch_axis_0 = _stitch_axis(x[0].shape, stitch_axis_size)
+        for item in x:
+            i_stitch_axis = _stitch_axis(item.shape, stitch_axis_size)
+            if i_stitch_axis is None or i_stitch_axis != stitch_axis_0:
+                return False
+        return True
+
+    if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a):
+        if reduce_nums(a.ops) >= 2:
+            return False
+        dom_outs = [op.output for op in dom.ops]
+        a_ins = [op_input for op in a.ops for op_input in op.inputs]
+        a_outs = [op.output for op in a.ops]
+        a_final_outs = [tensor for tensor in a_outs if tensor not in a_ins]
+        stitch_tensors = [tensor for tensor in dom_outs if tensor in a_ins]
+        if not _same_stitch_axis(stitch_tensors, a_final_outs, stitch_axis_size):
+            return False
+        return any([tensor_size(tensor) >= stitch_buffer_size for tensor in stitch_tensors])
+    return False
+
+
 class GraphSplitByPattern:
     """Graph splitter"""

     class ReachTable:
         """Reachable table"""
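For orientation, here is a small standalone sketch (not part of the patch) of what the nested _stitch_axis helper above computes: the shortest prefix of a shape whose element count reaches stitch_axis_size. The thresholds in the comments are the values passed to may_stitch later in this diff (1024 * 8 on GPU, 32 on Ascend).

def stitch_axis(shape, stitch_axis_size):
    """Shortest shape prefix whose element count reaches stitch_axis_size, else None."""
    prefix, size = [], 1
    for dim in shape:
        size *= dim
        prefix.append(dim)
        if size >= stitch_axis_size:
            return prefix
    return None

print(stitch_axis([32, 16, 16], 32))        # [32]          (Ascend-sized threshold)
print(stitch_axis([32, 16, 16], 1024 * 8))  # [32, 16, 16]  (GPU-sized threshold)
print(stitch_axis([4, 4], 1024 * 8))        # None: too small to stitch

may_stitch then requires every stitch tensor and final output of the candidate area to yield the same prefix, and at least one stitch tensor to hold at least stitch_buffer_size elements.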
@@ -94,6 +153,7 @@ class GraphSplitByPattern:
                             self.output_excluded.add(to)
                         else:
                             _gather_reduce_exclude(to)

                 _gather_reduce_exclude(init_op)
             self.unique_id = unique_id
             self.reach_tab = reach_tab
@@ -138,6 +198,7 @@ class GraphSplitByPattern:

         def fuse(self, area):
             """Fuse `area` to `self`"""

             def _update_relation(relations, a, r):
                 relations[a] = max(r, relations[a]) if a in relations else r

@@ -243,6 +304,7 @@ class GraphSplitByPattern:
         self.areas = []
         self.flags = flags
         self.enable_recompute = self.flags.get("enable_recompute_fusion", False)
+        self.enable_stitch_fusion = self.flags.get("enable_stitch_fusion", False)
         self.reach_tab = self.ReachTable(len(graph.ops) + 1 if self.enable_recompute else len(graph.ops))
         self.area_map = {}
         _, outputs = graph.deduce_parameters()
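The two switches read above come from the plain flags dictionary handed to the splitter. A minimal sketch of how the new switch is wired (keys taken from this diff; the dict values and the standalone reads are purely illustrative):

flags = {
    "enable_recompute_fusion": False,
    "enable_stitch_fusion": True,  # the switch this commit starts honoring on Ascend
}

# Mirrors the attribute setup in GraphSplitByPattern.__init__ shown above.
enable_recompute = flags.get("enable_recompute_fusion", False)
enable_stitch_fusion = flags.get("enable_stitch_fusion", False)
print(enable_recompute, enable_stitch_fusion)  # False True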
@@ -256,7 +318,7 @@ class GraphSplitByPattern:
             self.set_area_map([op], a)
         for a in self.areas:
             a.link_input(self.area_map)
-        for i in range(len(self.areas)-1, -1, -1):
+        for i in range(len(self.areas) - 1, -1, -1):
             self.areas[i].link_output()
         if self.enable_recompute:
             self.recom_area = self.Area(None, False, idx, self.reach_tab)
@@ -295,6 +357,7 @@ class GraphSplitByPattern:

     def fuse(self, selector):
         """Fuse areas"""

         def _fuse_area():
             for dominant in self.areas:
                 result = selector(dominant)
@@ -327,6 +390,7 @@ class GraphSplitByPattern:

     def hfuse(self, selector):
         """Fuse horizontal areas with same input tensor"""

         def _do_fuse(areas):
             for i in range(len(areas) - 1):
                 sibling = []
@@ -349,6 +413,7 @@ class GraphSplitByPattern:
                             self.areas.remove(area)
                         return True
             return False

         changed = False
         while True:
             for dom in self.areas:
@@ -426,6 +491,7 @@ class GraphSplitByPattern:

     def split_output_reshapes(self):
         """Force split the output Reshapes into other new area"""

         def _remove_output_reshape(reshape_ops, other_ops):
             def _run():
                 for op in reshape_ops:
@@ -434,6 +500,7 @@ class GraphSplitByPattern:
                         other_ops.append(op)
                         return True
                 return False

             while _run():
                 pass

@@ -511,6 +578,7 @@ class GraphSplitByPattern:

     def find_cheap_regions(self, dom):
         """extract all the cheap regions in dom area, toposort each region before return"""

         def _grow_region(region_ops, op, weight, inputs):
             """include op to region_ops if region grow"""
             # region successfully ends at inputs
@@ -559,6 +627,7 @@ class GraphSplitByPattern:

     def select_user_area(self, tail_tensor):
         """select the user area has only one edge to dom area"""

         def _get_edge_num(dom_area, user_area):
             """get edge num between two areas"""
             dom_graph = self.to_subgraph(dom_area)
@@ -583,8 +652,10 @@ class GraphSplitByPattern:

     def recompute_fuse(self):
         """find recompute regions and copy them out to new Areas"""

         def do_recompute_fuse():
             """split the unfusing pattern by add recompute area"""

             def recompute_cheap_region(dom):
                 for cheap_region in cheap_regions:
                     user_areas = self.select_user_area(cheap_region[-1].output)
@@ -597,6 +668,7 @@ class GraphSplitByPattern:
                         if self.recom_res:
                             return True
                 return False

             recompute_suc = False
             orig_areas = []
             orig_areas.extend(self.areas)
@@ -628,6 +700,7 @@ class GraphSplitGpu(GraphSplitByPattern):

     def pattern_fuse(self, fuse_func=None):
         """fuse Areas by pattern"""

         def _reshape(dom):
             if dom.pattern != PrimLib.RESHAPE:
                 return None
@@ -638,8 +711,8 @@ class GraphSplitGpu(GraphSplitByPattern):
                 min_area = a
             for a, _ in dom.in_relations.items():
                 if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \
-                    len(dom.ops[0].inputs[0].to_ops) == 1 and not a.is_output and \
-                    (min_area is None or a.pattern < min_area.pattern):
+                        len(dom.ops[0].inputs[0].to_ops) == 1 and not a.is_output and \
+                        (min_area is None or a.pattern < min_area.pattern):
                     min_area, forward_fuse = a, True
             return ([min_area], forward_fuse) if min_area else None

@@ -718,12 +791,6 @@ class GraphSplitGpu(GraphSplitByPattern):
                     fused.append(a)
             return fused, True

-        def _tensor_size(tensor):
-            size = 1
-            for i in tensor.shape:
-                size *= i
-            return size
-
         def _is_atomic_add_available(dom):
             if any(["Reduce" in x.prim for x in dom.ops[1:]]):
                 return False
@@ -739,23 +806,16 @@ class GraphSplitGpu(GraphSplitByPattern):
                 return reduce_size >= 1024
             return True

-        def _reduce_nums(ops):
-            count = 0
-            for op in ops:
-                if op.prim.startswith('Reduce'):
-                    count += 1
-            return count
-
         def _reduce_output(dom):
             if dom.pattern != PrimLib.REDUCE:
                 return None
-            if _reduce_nums(dom.ops) > 1:
+            if reduce_nums(dom.ops) > 1:
                 return None
             if _is_atomic_add_available(dom):
                 return None
-            is_all_reduce = _tensor_size(dom.ops[0].output) == 1
+            is_all_reduce = tensor_size(dom.ops[0].output) == 1
             # excluded large size all reduce
-            if is_all_reduce and _tensor_size(dom.ops[0].inputs[0]) > 1024 * 12:
+            if is_all_reduce and tensor_size(dom.ops[0].inputs[0]) > 1024 * 12:
                 return None

             fused = []
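A quick arithmetic check (not part of the patch) of the all-reduce gate above: when the reduce output has a single element, _reduce_output declines to fuse if the reduce input holds more than 1024 * 12 = 12288 elements.

def num_elements(shape):
    size = 1
    for dim in shape:
        size *= dim
    return size

LARGE_ALL_REDUCE_LIMIT = 1024 * 12  # threshold used in _reduce_output above

for shape in ([64, 128], [128, 128]):  # example shapes, illustrative only
    verdict = "excluded" if num_elements(shape) > LARGE_ALL_REDUCE_LIMIT else "eligible"
    print(shape, num_elements(shape), verdict)
# [64, 128] 8192 eligible
# [128, 128] 16384 excluded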
@@ -765,52 +825,17 @@ class GraphSplitGpu(GraphSplitByPattern):
                 fused.append(a)
             return fused, False

-        def _stitch_axis(shape):
-            stitch_axis = []
-            size = 1
-            for i in shape:
-                size = size * i
-                stitch_axis.append(i)
-                if size >= 1024 * 8:
-                    return stitch_axis
-            return None
-
-        def _same_stitch_axis(a, b):
-            x = []
-            x.extend(a)
-            x.extend(b)
-            stitch_axis = _stitch_axis(x[0].shape)
-            for item in x:
-                i_stitch_axis = _stitch_axis(item.shape)
-                if i_stitch_axis is None or i_stitch_axis != stitch_axis:
-                    return False
-            return True
-
-        def _may_stitch(dom, a, r):
-            if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a):
-                if _reduce_nums(a.ops) >= 2:
-                    return False
-                dom_outs = [op.output for op in dom.ops]
-                a_ins = [op_input for op in a.ops for op_input in op.inputs]
-                a_outs = [op.output for op in a.ops]
-                a_final_outs = [tensor for tensor in a_outs if tensor not in a_ins]
-                stitch_tensors = [tensor for tensor in dom_outs if tensor in a_ins]
-                if not _same_stitch_axis(stitch_tensors, a_final_outs):
-                    return False
-                return any([_tensor_size(tensor) >= 1024 * 1024 for tensor in stitch_tensors])
-            return False
-
         def _reduce_stitch(dom):
             if dom.pattern != PrimLib.REDUCE:
                 return None
-            if _tensor_size(dom.ops[0].output) == 1:
+            if tensor_size(dom.ops[0].output) == 1:
                 return None
-            if _tensor_size(dom.ops[0].inputs[0]) < 1024 * 12:
+            if tensor_size(dom.ops[0].inputs[0]) < 1024 * 12:
                 return None

             fused = []
             for a, r in dom.out_relations.items():
-                if not _may_stitch(dom, a, r):
+                if not may_stitch(dom, a, r, 1024 * 8, 1024 * 1024):
                     continue
                 if a.pattern == PrimLib.REDUCE:
                     if a.ops[0].attrs['reduce_axis'] == dom.ops[0].attrs['reduce_axis']:
@@ -881,7 +906,7 @@ class GraphSplitGpu(GraphSplitByPattern):
             appected_areas = {"TensorScatterAdd", "UnsortedSegmentSum"}
             for a, _ in dom.out_relations.items():
                 if _shape_consistent(gather_prims, appected_areas, dom, a) and \
-                    _count_target_prim(a.ops + dom.ops, appected_areas) < 2 and dom.check_acyclic(a):
+                        _count_target_prim(a.ops + dom.ops, appected_areas) < 2 and dom.check_acyclic(a):
                     return [a], False
             return None

@@ -916,8 +941,8 @@ class GraphSplitGpu(GraphSplitByPattern):
             for a in sibling:
                 op = a.ops[0]
                 if a.pattern == PrimLib.REDUCE and not a.stitch_info.stitch_ops and \
-                    PrimLib.is_reduce(op) and dom_op.inputs[0].shape == op.inputs[0].shape and \
-                    dom_op.attrs.get("reduce_axis") == op.attrs.get("reduce_axis"):
+                        PrimLib.is_reduce(op) and dom_op.inputs[0].shape == op.inputs[0].shape and \
+                        dom_op.attrs.get("reduce_axis") == op.attrs.get("reduce_axis"):
                     fused.append(a)
             return fused

@@ -927,7 +952,7 @@ class GraphSplitGpu(GraphSplitByPattern):
             fused = []
             for a in sibling:
                 if a.ops[0].prim == dom.ops[0].prim and dom.ops[0].output.shape == a.ops[0].output.shape and \
-                    dom.ops[0].inputs[0].shape == a.ops[0].inputs[0].shape:
+                        dom.ops[0].inputs[0].shape == a.ops[0].inputs[0].shape:
                     fused.append(a)
             return fused

@@ -945,7 +970,7 @@ class GraphSplitGpu(GraphSplitByPattern):
                 changed = self.fuse(_broadcast_opaque) or changed
                 changed = self.fuse(_gather_output) or changed
                 changed = self.fuse(_reduce_output) or changed
-                if enable_stitch_fusion:
+                if self.enable_stitch_fusion:
                     changed = self.fuse(_reduce_stitch) or changed
             self.fuse(_transpose)
             self.hfuse(_h_broadcast)
@@ -957,12 +982,11 @@ class GraphSplitGpu(GraphSplitByPattern):
                     fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \
                     fuse_func(_broadcast_width):
                 return
-            if fuse_func(_reduce_output) or (enable_stitch_fusion and fuse_func(_reduce_stitch)):
+            if fuse_func(_reduce_output) or (self.enable_stitch_fusion and fuse_func(_reduce_stitch)):
                 return
             fuse_func(_transpose)
             return

-        enable_stitch_fusion = self.flags.get("enable_stitch_fusion", False)
         if fuse_func is None:
             _fuse_loop()
         else:
@@ -976,6 +1000,7 @@ class GraphSplitAscend(GraphSplitByPattern):

     def get_default_mode(self, op):
         """Get efault mode for Ascend"""

         def _dtype_same(tensors):
             dtype = tensors[0].dtype
             for tensor_ in tensors:
@@ -992,15 +1017,10 @@ class GraphSplitAscend(GraphSplitByPattern):

     def pattern_fuse(self, fuse_func=None):
         """fuse Areas by pattern"""
-        def _tensor_size(tensor):
-            size = 1
-            for i in tensor.shape:
-                size *= i
-            return size

         def _likely_multicore(dom):
             op = dom.dom_op()
-            iter_size = _tensor_size(op.output if not PrimLib.is_reduce(op) else op.inputs[0])
+            iter_size = tensor_size(op.output if not PrimLib.is_reduce(op) else op.inputs[0])
             return iter_size > 1024

         def _reshape(dom):
@@ -1013,8 +1033,8 @@ class GraphSplitAscend(GraphSplitByPattern):
                 min_area = a
             for a, _ in dom.in_relations.items():
                 if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \
-                    len(dom.ops[0].inputs[0].to_ops) == 1 and not a.is_output and \
-                    (min_area is None or a.pattern < min_area.pattern):
+                        len(dom.ops[0].inputs[0].to_ops) == 1 and not a.is_output and \
+                        (min_area is None or a.pattern < min_area.pattern):
                     min_area, forward_fuse = a, True
             return ([min_area], forward_fuse) if min_area else None

@@ -1111,6 +1131,27 @@ class GraphSplitAscend(GraphSplitByPattern):
                 fused.append(a)
             return fused, False

+        def _reduce_stitch(dom):
+            if dom.pattern != PrimLib.REDUCE:
+                return None
+            if tensor_size(dom.ops[0].output) == 1:
+                return None
+            if tensor_size(dom.ops[0].inputs[0]) < 32 * 16 * 16:
+                return None
+
+            fused = []
+            for a, r in dom.out_relations.items():
+                if not may_stitch(dom, a, r, 32, 32 * 16 * 16):
+                    continue
+                if a.pattern == PrimLib.REDUCE:
+                    if a.ops[0].attrs['reduce_axis'] == dom.ops[0].attrs['reduce_axis']:
+                        dom.stitch_info.stitch_ops.add(dom.ops[0].output.name)
+                        fused.append(a)
+                elif a.pattern <= PrimLib.BROADCAST:
+                    dom.stitch_info.stitch_ops.add(dom.ops[0].output.name)
+                    fused.append(a)
+            return fused, False
+
         def _transdata_pattern_support(dom, a):
             transdata_op = dom.dom_op()

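The new Ascend _reduce_stitch above reuses the shared may_stitch helper but with much smaller thresholds than the GPU version earlier in this diff. A side-by-side restatement of the two calls (numbers copied from the diff; the dict itself is only illustrative):

# (stitch_axis_size, stitch_buffer_size) passed to may_stitch, in elements
STITCH_THRESHOLDS = {
    "gpu": (1024 * 8, 1024 * 1024),  # 8192 and 1048576
    "ascend": (32, 32 * 16 * 16),    # 32 and 8192
}

for backend, (axis_size, buffer_size) in STITCH_THRESHOLDS.items():
    print(f"{backend}: stitch axis must cover >= {axis_size} elements, "
          f"stitched tensor must hold >= {buffer_size} elements")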
@@ -1127,6 +1168,7 @@ class GraphSplitAscend(GraphSplitByPattern):
                     if dim % cube_size != 0:
                         res = True
                 return res

             has_pad = _has_pad()
             if has_pad:
                 return False
@@ -1145,7 +1187,7 @@ class GraphSplitAscend(GraphSplitByPattern):
             if src_format == DF.FRAC_NZ and dst_format in (DF.DEFAULT, DF.NCHW):
                 return True
             # For the Default/NCHW to FRAC_NZ, currently only the Cast+Transdata is supported
-            if src_format in (DF.DEFAULT, DF.NCHW) and dst_format == DF.FRAC_NZ\
+            if src_format in (DF.DEFAULT, DF.NCHW) and dst_format == DF.FRAC_NZ \
                     and len(a.ops) == 1 and a.dom_op().prim == "Cast" and not a.is_output:
                 return True
             return False
@@ -1171,6 +1213,8 @@ class GraphSplitAscend(GraphSplitByPattern):
                 changed = self.fuse(_broadcast_width) or changed
                 changed = self.fuse(_matmul_depth) or changed
                 changed = self.fuse(_reduce_output) or changed
+                if self.enable_stitch_fusion:
+                    changed = self.fuse(_reduce_stitch) or changed
             self.fuse(_transdata)

         def _fuse_once(fuse_func):
@@ -101,6 +101,7 @@ class OpInfer:

 class _Elemwise(OpInfer):
     """Common infer for elementwise operators"""

     @staticmethod
     def broadcast_shape(shapes):
         """deduce broadcast shape using same rules as numpy"""
@@ -120,25 +121,24 @@ class _Elemwise(OpInfer):
     @staticmethod
     def defaultformat_to_nz(default_shape):
         """default format shape to fractal_Nz format shape"""
         if len(default_shape) not in (1, 2):
             raise GKException("shape is too long!")
+        more_two_d_shape, two_d_shape = default_shape[:-2], default_shape[-2:]
         # (32) or (1, 32) -> (2, 1, 1, 16)
-        if len(default_shape) == 1 or (len(default_shape) == 2 and default_shape[0] == 1):
-            shape = [default_shape[-1] // 16, 1, 1, 16]
-            if default_shape[-1] % 16 != 0:
+        if len(two_d_shape) == 1 or (len(two_d_shape) == 2 and two_d_shape[0] == 1):
+            shape = [two_d_shape[-1] // 16, 1, 1, 16]
+            if two_d_shape[-1] % 16 != 0:
                 raise GKException("should be multiplies of 16")
             return shape
         # (32, 1) -> (1, 2, 16, 1)
-        if len(default_shape) == 2 and default_shape[1] == 1:
-            shape = [1, default_shape[0] // 16, 16, 1]
-            if default_shape[0] % 16 != 0:
+        if len(two_d_shape) == 2 and two_d_shape[1] == 1:
+            shape = [1, two_d_shape[0] // 16, 16, 1]
+            if two_d_shape[0] % 16 != 0:
                 raise GKException("should be multiples of 16")
             return shape
         # (32, 48) -> (3, 2, 16, 16)
-        shape = [default_shape[1] // 16, default_shape[0] // 16, 16, 16]
-        if default_shape[0] % 16 != 0 or default_shape[1] % 16 != 0:
+        shape = [two_d_shape[1] // 16, two_d_shape[0] // 16, 16, 16]
+        if two_d_shape[0] % 16 != 0 or two_d_shape[1] % 16 != 0:
             raise GKException("should be multiples of 16")
-        return shape
+        return more_two_d_shape + shape

     def _infer_shape(self):
         """returns the output shape with broadcast"""
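To make the shape mapping above concrete, here is a standalone restatement (illustrative only, with the multiples-of-16 checks omitted) of defaultformat_to_nz for the three commented cases; in the patched method the leading dimensions beyond the last two are preserved via more_two_d_shape.

def default_to_nz(shape):
    """Simplified FRAC_NZ shape mapping for the commented examples above."""
    lead, tail = list(shape[:-2]), shape[-2:]
    if len(tail) == 1 or (len(tail) == 2 and tail[0] == 1):
        return lead + [tail[-1] // 16, 1, 1, 16]            # (32) / (1, 32) -> (2, 1, 1, 16)
    if len(tail) == 2 and tail[1] == 1:
        return lead + [1, tail[0] // 16, 16, 1]             # (32, 1) -> (1, 2, 16, 1)
    return lead + [tail[1] // 16, tail[0] // 16, 16, 16]    # (32, 48) -> (3, 2, 16, 16)

print(default_to_nz([32]))      # [2, 1, 1, 16]
print(default_to_nz([32, 1]))   # [1, 2, 16, 1]
print(default_to_nz([32, 48]))  # [3, 2, 16, 16]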