From c393deffcb3b402481c595505b8181af3fb39081 Mon Sep 17 00:00:00 2001 From: r1chardf1d0 Date: Tue, 31 Aug 2021 14:50:53 +0800 Subject: [PATCH] [GraphKernel] ascend support stitch fusion --- .../graph_kernel/model/graph_split.py | 190 +++++++++++------- .../_extends/graph_kernel/model/op_infer.py | 22 +- 2 files changed, 128 insertions(+), 84 deletions(-) diff --git a/mindspore/_extends/graph_kernel/model/graph_split.py b/mindspore/_extends/graph_kernel/model/graph_split.py index 06d7636149f..4de3c049742 100644 --- a/mindspore/_extends/graph_kernel/model/graph_split.py +++ b/mindspore/_extends/graph_kernel/model/graph_split.py @@ -19,8 +19,67 @@ from .model import PrimLib, Graph, Tensor, Operator from .model import DataFormat as DF +def tensor_size(tensor): + """get tensor size""" + size = 1 + for i in tensor.shape: + size *= i + return size + + +def reduce_nums(ops): + """get reduce nums""" + count = 0 + for op in ops: + if op.prim.startswith('Reduce'): + count += 1 + return count + + +def may_stitch(dom, a, r, stitch_axis_size, stitch_buffer_size): + """check if can stitch""" + + def _same_stitch_axis(stitch_tensors, final_outs, stitch_axis_size): + """does a and b have same stitch axis""" + + def _stitch_axis(shape, stitch_axis_size): + """get stitch axis""" + stitchaxis = [] + size = 1 + for i in shape: + size = size * i + stitchaxis.append(i) + if size >= stitch_axis_size: + return stitchaxis + return None + + x = [] + x.extend(stitch_tensors) + x.extend(final_outs) + stitch_axis_0 = _stitch_axis(x[0].shape, stitch_axis_size) + for item in x: + i_stitch_axis = _stitch_axis(item.shape, stitch_axis_size) + if i_stitch_axis is None or i_stitch_axis != stitch_axis_0: + return False + return True + + if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a): + if reduce_nums(a.ops) >= 2: + return False + dom_outs = [op.output for op in dom.ops] + a_ins = [op_input for op in a.ops for op_input in op.inputs] + a_outs = [op.output for op in a.ops] + a_final_outs = [tensor for tensor in a_outs if tensor not in a_ins] + stitch_tensors = [tensor for tensor in dom_outs if tensor in a_ins] + if not _same_stitch_axis(stitch_tensors, a_final_outs, stitch_axis_size): + return False + return any([tensor_size(tensor) >= stitch_buffer_size for tensor in stitch_tensors]) + return False + + class GraphSplitByPattern: """Graph splitter""" + class ReachTable: """Reachable table""" @@ -94,6 +153,7 @@ class GraphSplitByPattern: self.output_excluded.add(to) else: _gather_reduce_exclude(to) + _gather_reduce_exclude(init_op) self.unique_id = unique_id self.reach_tab = reach_tab @@ -138,6 +198,7 @@ class GraphSplitByPattern: def fuse(self, area): """Fuse `area` to `self`""" + def _update_relation(relations, a, r): relations[a] = max(r, relations[a]) if a in relations else r @@ -243,6 +304,7 @@ class GraphSplitByPattern: self.areas = [] self.flags = flags self.enable_recompute = self.flags.get("enable_recompute_fusion", False) + self.enable_stitch_fusion = self.flags.get("enable_stitch_fusion", False) self.reach_tab = self.ReachTable(len(graph.ops) + 1 if self.enable_recompute else len(graph.ops)) self.area_map = {} _, outputs = graph.deduce_parameters() @@ -256,7 +318,7 @@ class GraphSplitByPattern: self.set_area_map([op], a) for a in self.areas: a.link_input(self.area_map) - for i in range(len(self.areas)-1, -1, -1): + for i in range(len(self.areas) - 1, -1, -1): self.areas[i].link_output() if self.enable_recompute: self.recom_area = self.Area(None, False, idx, self.reach_tab) @@ -295,6 
+357,7 @@ class GraphSplitByPattern: def fuse(self, selector): """Fuse areas""" + def _fuse_area(): for dominant in self.areas: result = selector(dominant) @@ -327,6 +390,7 @@ class GraphSplitByPattern: def hfuse(self, selector): """Fuse horizontal areas with same input tensor""" + def _do_fuse(areas): for i in range(len(areas) - 1): sibling = [] @@ -349,6 +413,7 @@ class GraphSplitByPattern: self.areas.remove(area) return True return False + changed = False while True: for dom in self.areas: @@ -426,6 +491,7 @@ class GraphSplitByPattern: def split_output_reshapes(self): """Force split the output Reshapes into other new area""" + def _remove_output_reshape(reshape_ops, other_ops): def _run(): for op in reshape_ops: @@ -434,6 +500,7 @@ class GraphSplitByPattern: other_ops.append(op) return True return False + while _run(): pass @@ -511,6 +578,7 @@ class GraphSplitByPattern: def find_cheap_regions(self, dom): """extract all the cheap regions in dom area, toposort each region before return""" + def _grow_region(region_ops, op, weight, inputs): """include op to region_ops if region grow""" # region successfully ends at inputs @@ -559,6 +627,7 @@ class GraphSplitByPattern: def select_user_area(self, tail_tensor): """select the user area has only one edge to dom area""" + def _get_edge_num(dom_area, user_area): """get edge num between two areas""" dom_graph = self.to_subgraph(dom_area) @@ -583,8 +652,10 @@ class GraphSplitByPattern: def recompute_fuse(self): """find recompute regions and copy them out to new Areas""" + def do_recompute_fuse(): """split the unfusing pattern by add recompute area""" + def recompute_cheap_region(dom): for cheap_region in cheap_regions: user_areas = self.select_user_area(cheap_region[-1].output) @@ -597,6 +668,7 @@ class GraphSplitByPattern: if self.recom_res: return True return False + recompute_suc = False orig_areas = [] orig_areas.extend(self.areas) @@ -628,6 +700,7 @@ class GraphSplitGpu(GraphSplitByPattern): def pattern_fuse(self, fuse_func=None): """fuse Areas by pattern""" + def _reshape(dom): if dom.pattern != PrimLib.RESHAPE: return None @@ -638,8 +711,8 @@ class GraphSplitGpu(GraphSplitByPattern): min_area = a for a, _ in dom.in_relations.items(): if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \ - len(dom.ops[0].inputs[0].to_ops) == 1 and not a.is_output and \ - (min_area is None or a.pattern < min_area.pattern): + len(dom.ops[0].inputs[0].to_ops) == 1 and not a.is_output and \ + (min_area is None or a.pattern < min_area.pattern): min_area, forward_fuse = a, True return ([min_area], forward_fuse) if min_area else None @@ -718,12 +791,6 @@ class GraphSplitGpu(GraphSplitByPattern): fused.append(a) return fused, True - def _tensor_size(tensor): - size = 1 - for i in tensor.shape: - size *= i - return size - def _is_atomic_add_available(dom): if any(["Reduce" in x.prim for x in dom.ops[1:]]): return False @@ -739,23 +806,16 @@ class GraphSplitGpu(GraphSplitByPattern): return reduce_size >= 1024 return True - def _reduce_nums(ops): - count = 0 - for op in ops: - if op.prim.startswith('Reduce'): - count += 1 - return count - def _reduce_output(dom): if dom.pattern != PrimLib.REDUCE: return None - if _reduce_nums(dom.ops) > 1: + if reduce_nums(dom.ops) > 1: return None if _is_atomic_add_available(dom): return None - is_all_reduce = _tensor_size(dom.ops[0].output) == 1 + is_all_reduce = tensor_size(dom.ops[0].output) == 1 # excluded large size all reduce - if is_all_reduce and _tensor_size(dom.ops[0].inputs[0]) > 1024 * 12: + if is_all_reduce and 
tensor_size(dom.ops[0].inputs[0]) > 1024 * 12: return None fused = [] @@ -765,52 +825,17 @@ class GraphSplitGpu(GraphSplitByPattern): fused.append(a) return fused, False - def _stitch_axis(shape): - stitch_axis = [] - size = 1 - for i in shape: - size = size * i - stitch_axis.append(i) - if size >= 1024 * 8: - return stitch_axis - return None - - def _same_stitch_axis(a, b): - x = [] - x.extend(a) - x.extend(b) - stitch_axis = _stitch_axis(x[0].shape) - for item in x: - i_stitch_axis = _stitch_axis(item.shape) - if i_stitch_axis is None or i_stitch_axis != stitch_axis: - return False - return True - - def _may_stitch(dom, a, r): - if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a): - if _reduce_nums(a.ops) >= 2: - return False - dom_outs = [op.output for op in dom.ops] - a_ins = [op_input for op in a.ops for op_input in op.inputs] - a_outs = [op.output for op in a.ops] - a_final_outs = [tensor for tensor in a_outs if tensor not in a_ins] - stitch_tensors = [tensor for tensor in dom_outs if tensor in a_ins] - if not _same_stitch_axis(stitch_tensors, a_final_outs): - return False - return any([_tensor_size(tensor) >= 1024 * 1024 for tensor in stitch_tensors]) - return False - def _reduce_stitch(dom): if dom.pattern != PrimLib.REDUCE: return None - if _tensor_size(dom.ops[0].output) == 1: + if tensor_size(dom.ops[0].output) == 1: return None - if _tensor_size(dom.ops[0].inputs[0]) < 1024 * 12: + if tensor_size(dom.ops[0].inputs[0]) < 1024 * 12: return None fused = [] for a, r in dom.out_relations.items(): - if not _may_stitch(dom, a, r): + if not may_stitch(dom, a, r, 1024 * 8, 1024 * 1024): continue if a.pattern == PrimLib.REDUCE: if a.ops[0].attrs['reduce_axis'] == dom.ops[0].attrs['reduce_axis']: @@ -881,7 +906,7 @@ class GraphSplitGpu(GraphSplitByPattern): appected_areas = {"TensorScatterAdd", "UnsortedSegmentSum"} for a, _ in dom.out_relations.items(): if _shape_consistent(gather_prims, appected_areas, dom, a) and \ - _count_target_prim(a.ops + dom.ops, appected_areas) < 2 and dom.check_acyclic(a): + _count_target_prim(a.ops + dom.ops, appected_areas) < 2 and dom.check_acyclic(a): return [a], False return None @@ -916,8 +941,8 @@ class GraphSplitGpu(GraphSplitByPattern): for a in sibling: op = a.ops[0] if a.pattern == PrimLib.REDUCE and not a.stitch_info.stitch_ops and \ - PrimLib.is_reduce(op) and dom_op.inputs[0].shape == op.inputs[0].shape and \ - dom_op.attrs.get("reduce_axis") == op.attrs.get("reduce_axis"): + PrimLib.is_reduce(op) and dom_op.inputs[0].shape == op.inputs[0].shape and \ + dom_op.attrs.get("reduce_axis") == op.attrs.get("reduce_axis"): fused.append(a) return fused @@ -927,7 +952,7 @@ class GraphSplitGpu(GraphSplitByPattern): fused = [] for a in sibling: if a.ops[0].prim == dom.ops[0].prim and dom.ops[0].output.shape == a.ops[0].output.shape and \ - dom.ops[0].inputs[0].shape == a.ops[0].inputs[0].shape: + dom.ops[0].inputs[0].shape == a.ops[0].inputs[0].shape: fused.append(a) return fused @@ -945,7 +970,7 @@ class GraphSplitGpu(GraphSplitByPattern): changed = self.fuse(_broadcast_opaque) or changed changed = self.fuse(_gather_output) or changed changed = self.fuse(_reduce_output) or changed - if enable_stitch_fusion: + if self.enable_stitch_fusion: changed = self.fuse(_reduce_stitch) or changed self.fuse(_transpose) self.hfuse(_h_broadcast) @@ -957,12 +982,11 @@ class GraphSplitGpu(GraphSplitByPattern): fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \ fuse_func(_broadcast_width): return - if 
fuse_func(_reduce_output) or (enable_stitch_fusion and fuse_func(_reduce_stitch)): + if fuse_func(_reduce_output) or (self.enable_stitch_fusion and fuse_func(_reduce_stitch)): return fuse_func(_transpose) return - enable_stitch_fusion = self.flags.get("enable_stitch_fusion", False) if fuse_func is None: _fuse_loop() else: @@ -976,6 +1000,7 @@ class GraphSplitAscend(GraphSplitByPattern): def get_default_mode(self, op): """Get efault mode for Ascend""" + def _dtype_same(tensors): dtype = tensors[0].dtype for tensor_ in tensors: @@ -992,15 +1017,10 @@ class GraphSplitAscend(GraphSplitByPattern): def pattern_fuse(self, fuse_func=None): """fuse Areas by pattern""" - def _tensor_size(tensor): - size = 1 - for i in tensor.shape: - size *= i - return size def _likely_multicore(dom): op = dom.dom_op() - iter_size = _tensor_size(op.output if not PrimLib.is_reduce(op) else op.inputs[0]) + iter_size = tensor_size(op.output if not PrimLib.is_reduce(op) else op.inputs[0]) return iter_size > 1024 def _reshape(dom): @@ -1013,8 +1033,8 @@ class GraphSplitAscend(GraphSplitByPattern): min_area = a for a, _ in dom.in_relations.items(): if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \ - len(dom.ops[0].inputs[0].to_ops) == 1 and not a.is_output and \ - (min_area is None or a.pattern < min_area.pattern): + len(dom.ops[0].inputs[0].to_ops) == 1 and not a.is_output and \ + (min_area is None or a.pattern < min_area.pattern): min_area, forward_fuse = a, True return ([min_area], forward_fuse) if min_area else None @@ -1111,6 +1131,27 @@ class GraphSplitAscend(GraphSplitByPattern): fused.append(a) return fused, False + def _reduce_stitch(dom): + if dom.pattern != PrimLib.REDUCE: + return None + if tensor_size(dom.ops[0].output) == 1: + return None + if tensor_size(dom.ops[0].inputs[0]) < 32 * 16 * 16: + return None + + fused = [] + for a, r in dom.out_relations.items(): + if not may_stitch(dom, a, r, 32, 32 * 16 * 16): + continue + if a.pattern == PrimLib.REDUCE: + if a.ops[0].attrs['reduce_axis'] == dom.ops[0].attrs['reduce_axis']: + dom.stitch_info.stitch_ops.add(dom.ops[0].output.name) + fused.append(a) + elif a.pattern <= PrimLib.BROADCAST: + dom.stitch_info.stitch_ops.add(dom.ops[0].output.name) + fused.append(a) + return fused, False + def _transdata_pattern_support(dom, a): transdata_op = dom.dom_op() @@ -1127,6 +1168,7 @@ class GraphSplitAscend(GraphSplitByPattern): if dim % cube_size != 0: res = True return res + has_pad = _has_pad() if has_pad: return False @@ -1145,7 +1187,7 @@ class GraphSplitAscend(GraphSplitByPattern): if src_format == DF.FRAC_NZ and dst_format in (DF.DEFAULT, DF.NCHW): return True # For the Default/NCHW to FRAC_NZ, currently only the Cast+Transdata is supported - if src_format in (DF.DEFAULT, DF.NCHW) and dst_format == DF.FRAC_NZ\ + if src_format in (DF.DEFAULT, DF.NCHW) and dst_format == DF.FRAC_NZ \ and len(a.ops) == 1 and a.dom_op().prim == "Cast" and not a.is_output: return True return False @@ -1171,6 +1213,8 @@ class GraphSplitAscend(GraphSplitByPattern): changed = self.fuse(_broadcast_width) or changed changed = self.fuse(_matmul_depth) or changed changed = self.fuse(_reduce_output) or changed + if self.enable_stitch_fusion: + changed = self.fuse(_reduce_stitch) or changed self.fuse(_transdata) def _fuse_once(fuse_func): diff --git a/mindspore/_extends/graph_kernel/model/op_infer.py b/mindspore/_extends/graph_kernel/model/op_infer.py index 760f6586efe..435917e929e 100644 --- a/mindspore/_extends/graph_kernel/model/op_infer.py +++ 
b/mindspore/_extends/graph_kernel/model/op_infer.py @@ -101,6 +101,7 @@ class OpInfer: class _Elemwise(OpInfer): """Common infer for elementwise operators""" + @staticmethod def broadcast_shape(shapes): """deduce broadcast shape using same rules as numpy""" @@ -120,25 +121,24 @@ class _Elemwise(OpInfer): @staticmethod def defaultformat_to_nz(default_shape): """default format shape to fractal_Nz format shape""" - if len(default_shape) not in (1, 2): - raise GKException("shape is too long!") + more_two_d_shape, two_d_shape = default_shape[:-2], default_shape[-2:] # (32) or (1, 32) -> (2, 1, 1, 16) - if len(default_shape) == 1 or (len(default_shape) == 2 and default_shape[0] == 1): - shape = [default_shape[-1] // 16, 1, 1, 16] - if default_shape[-1] % 16 != 0: + if len(two_d_shape) == 1 or (len(two_d_shape) == 2 and two_d_shape[0] == 1): + shape = [two_d_shape[-1] // 16, 1, 1, 16] + if two_d_shape[-1] % 16 != 0: raise GKException("should be multiplies of 16") return shape # (32, 1) -> (1, 2, 16, 1) - if len(default_shape) == 2 and default_shape[1] == 1: - shape = [1, default_shape[0] // 16, 16, 1] - if default_shape[0] % 16 != 0: + if len(two_d_shape) == 2 and two_d_shape[1] == 1: + shape = [1, two_d_shape[0] // 16, 16, 1] + if two_d_shape[0] % 16 != 0: raise GKException("should be multiples of 16") return shape # (32, 48) -> (3, 2, 16, 16) - shape = [default_shape[1] // 16, default_shape[0] // 16, 16, 16] - if default_shape[0] % 16 != 0 or default_shape[1] % 16 != 0: + shape = [two_d_shape[1] // 16, two_d_shape[0] // 16, 16, 16] + if two_d_shape[0] % 16 != 0 or two_d_shape[1] % 16 != 0: raise GKException("should be multiples of 16") - return shape + return more_two_d_shape + shape def _infer_shape(self): """returns the output shape with broadcast"""
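
Note on the new helpers in graph_split.py: may_stitch decides whether a REDUCE-pattern area's outputs can be stitched into a following area. Every tensor crossing the boundary plus every final output of the candidate area must share the same leading-axis prefix whose element count reaches stitch_axis_size, and at least one crossing tensor must hold at least stitch_buffer_size elements; the GPU pass calls it with (1024 * 8, 1024 * 1024), the Ascend pass with (32, 32 * 16 * 16). A minimal standalone sketch of those two shape checks, reproducing only the inner _stitch_axis/tensor_size logic on plain shape tuples (the Area/PrimLib machinery is assumed away, and the sample shapes are invented for illustration):

    def stitch_axis(shape, stitch_axis_size):
        """Leading-axis prefix whose product first reaches stitch_axis_size, or None."""
        prefix, size = [], 1
        for dim in shape:
            size *= dim
            prefix.append(dim)
            if size >= stitch_axis_size:
                return prefix
        return None

    def tensor_size(shape):
        """Total element count of a shape."""
        size = 1
        for dim in shape:
            size *= dim
        return size

    # Ascend thresholds from this patch: stitch_axis_size=32, stitch_buffer_size=32*16*16.
    stitch_shapes = [(64, 128, 16), (64, 128, 16)]   # tensors flowing from dom into the candidate area
    final_shapes = [(64, 16)]                        # final outputs of the candidate area
    prefixes = [stitch_axis(s, 32) for s in stitch_shapes + final_shapes]
    same_axis = all(p is not None and p == prefixes[0] for p in prefixes)
    big_enough = any(tensor_size(s) >= 32 * 16 * 16 for s in stitch_shapes)
    print(same_axis, big_enough)                     # True True -> stitch fusion would be considered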
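
Note on enabling the pass: the flag is now read once in GraphSplitByPattern.__init__ (self.enable_stitch_fusion = self.flags.get("enable_stitch_fusion", False)) and checked before _reduce_stitch in both the GPU fuse loop and the new Ascend one. A hedged usage sketch; the (graph, flags) constructor and pattern_fuse() are taken from this file, while the graph object itself is assumed to be built by the usual graph_kernel model path:

    # Assumed: `graph` is a model.Graph produced elsewhere in graph_kernel.
    flags = {"enable_stitch_fusion": True, "enable_recompute_fusion": False}
    splitter = GraphSplitAscend(graph, flags)   # GraphSplitGpu(graph, flags) is gated the same way
    splitter.pattern_fuse()                     # _reduce_stitch is only attempted when the flag is set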
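
Note on the op_infer.py change: defaultformat_to_nz no longer rejects shapes longer than 2-D; the trailing one or two dims are mapped to fractal_NZ exactly as before, and in the general (M, N) branch the leading dims are prepended to the result via more_two_d_shape. A standalone sketch of that general branch with two worked shapes (GKException replaced by a plain ValueError so the example is self-contained):

    def default_to_nz_general(default_shape):
        """General (..., M, N) case: trailing 2-D block -> [N//16, M//16, 16, 16], leading dims kept."""
        batch, (m, n) = list(default_shape[:-2]), default_shape[-2:]
        if m % 16 != 0 or n % 16 != 0:
            raise ValueError("should be multiples of 16")
        return batch + [n // 16, m // 16, 16, 16]

    print(default_to_nz_general([32, 48]))     # [3, 2, 16, 16]
    print(default_to_nz_general([8, 32, 48]))  # [8, 3, 2, 16, 16] -- a 3-D shape was rejected before this patch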