!22677 [GraphKernel] ascend support stitch fusion
Merge pull request !22677 from r1chardf1d0/master
This commit is contained in:
commit
315431b9a0
|
@ -19,8 +19,67 @@ from .model import PrimLib, Graph, Tensor, Operator
|
||||||
from .model import DataFormat as DF
|
from .model import DataFormat as DF
|
||||||
|
|
||||||
|
|
||||||
|
def tensor_size(tensor):
|
||||||
|
"""get tensor size"""
|
||||||
|
size = 1
|
||||||
|
for i in tensor.shape:
|
||||||
|
size *= i
|
||||||
|
return size
|
||||||
|
|
||||||
|
|
||||||
|
def reduce_nums(ops):
|
||||||
|
"""get reduce nums"""
|
||||||
|
count = 0
|
||||||
|
for op in ops:
|
||||||
|
if op.prim.startswith('Reduce'):
|
||||||
|
count += 1
|
||||||
|
return count
|
||||||
|
|
||||||
|
|
||||||
|
def may_stitch(dom, a, r, stitch_axis_size, stitch_buffer_size):
|
||||||
|
"""check if can stitch"""
|
||||||
|
|
||||||
|
def _same_stitch_axis(stitch_tensors, final_outs, stitch_axis_size):
|
||||||
|
"""does a and b have same stitch axis"""
|
||||||
|
|
||||||
|
def _stitch_axis(shape, stitch_axis_size):
|
||||||
|
"""get stitch axis"""
|
||||||
|
stitchaxis = []
|
||||||
|
size = 1
|
||||||
|
for i in shape:
|
||||||
|
size = size * i
|
||||||
|
stitchaxis.append(i)
|
||||||
|
if size >= stitch_axis_size:
|
||||||
|
return stitchaxis
|
||||||
|
return None
|
||||||
|
|
||||||
|
x = []
|
||||||
|
x.extend(stitch_tensors)
|
||||||
|
x.extend(final_outs)
|
||||||
|
stitch_axis_0 = _stitch_axis(x[0].shape, stitch_axis_size)
|
||||||
|
for item in x:
|
||||||
|
i_stitch_axis = _stitch_axis(item.shape, stitch_axis_size)
|
||||||
|
if i_stitch_axis is None or i_stitch_axis != stitch_axis_0:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a):
|
||||||
|
if reduce_nums(a.ops) >= 2:
|
||||||
|
return False
|
||||||
|
dom_outs = [op.output for op in dom.ops]
|
||||||
|
a_ins = [op_input for op in a.ops for op_input in op.inputs]
|
||||||
|
a_outs = [op.output for op in a.ops]
|
||||||
|
a_final_outs = [tensor for tensor in a_outs if tensor not in a_ins]
|
||||||
|
stitch_tensors = [tensor for tensor in dom_outs if tensor in a_ins]
|
||||||
|
if not _same_stitch_axis(stitch_tensors, a_final_outs, stitch_axis_size):
|
||||||
|
return False
|
||||||
|
return any([tensor_size(tensor) >= stitch_buffer_size for tensor in stitch_tensors])
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class GraphSplitByPattern:
|
class GraphSplitByPattern:
|
||||||
"""Graph splitter"""
|
"""Graph splitter"""
|
||||||
|
|
||||||
class ReachTable:
|
class ReachTable:
|
||||||
"""Reachable table"""
|
"""Reachable table"""
|
||||||
|
|
||||||
|
@ -94,6 +153,7 @@ class GraphSplitByPattern:
|
||||||
self.output_excluded.add(to)
|
self.output_excluded.add(to)
|
||||||
else:
|
else:
|
||||||
_gather_reduce_exclude(to)
|
_gather_reduce_exclude(to)
|
||||||
|
|
||||||
_gather_reduce_exclude(init_op)
|
_gather_reduce_exclude(init_op)
|
||||||
self.unique_id = unique_id
|
self.unique_id = unique_id
|
||||||
self.reach_tab = reach_tab
|
self.reach_tab = reach_tab
|
||||||
|
@ -138,6 +198,7 @@ class GraphSplitByPattern:
|
||||||
|
|
||||||
def fuse(self, area):
|
def fuse(self, area):
|
||||||
"""Fuse `area` to `self`"""
|
"""Fuse `area` to `self`"""
|
||||||
|
|
||||||
def _update_relation(relations, a, r):
|
def _update_relation(relations, a, r):
|
||||||
relations[a] = max(r, relations[a]) if a in relations else r
|
relations[a] = max(r, relations[a]) if a in relations else r
|
||||||
|
|
||||||
|
@ -243,6 +304,7 @@ class GraphSplitByPattern:
|
||||||
self.areas = []
|
self.areas = []
|
||||||
self.flags = flags
|
self.flags = flags
|
||||||
self.enable_recompute = self.flags.get("enable_recompute_fusion", False)
|
self.enable_recompute = self.flags.get("enable_recompute_fusion", False)
|
||||||
|
self.enable_stitch_fusion = self.flags.get("enable_stitch_fusion", False)
|
||||||
self.reach_tab = self.ReachTable(len(graph.ops) + 1 if self.enable_recompute else len(graph.ops))
|
self.reach_tab = self.ReachTable(len(graph.ops) + 1 if self.enable_recompute else len(graph.ops))
|
||||||
self.area_map = {}
|
self.area_map = {}
|
||||||
_, outputs = graph.deduce_parameters()
|
_, outputs = graph.deduce_parameters()
|
||||||
|
@ -256,7 +318,7 @@ class GraphSplitByPattern:
|
||||||
self.set_area_map([op], a)
|
self.set_area_map([op], a)
|
||||||
for a in self.areas:
|
for a in self.areas:
|
||||||
a.link_input(self.area_map)
|
a.link_input(self.area_map)
|
||||||
for i in range(len(self.areas)-1, -1, -1):
|
for i in range(len(self.areas) - 1, -1, -1):
|
||||||
self.areas[i].link_output()
|
self.areas[i].link_output()
|
||||||
if self.enable_recompute:
|
if self.enable_recompute:
|
||||||
self.recom_area = self.Area(None, False, idx, self.reach_tab)
|
self.recom_area = self.Area(None, False, idx, self.reach_tab)
|
||||||
|
@ -295,6 +357,7 @@ class GraphSplitByPattern:
|
||||||
|
|
||||||
def fuse(self, selector):
|
def fuse(self, selector):
|
||||||
"""Fuse areas"""
|
"""Fuse areas"""
|
||||||
|
|
||||||
def _fuse_area():
|
def _fuse_area():
|
||||||
for dominant in self.areas:
|
for dominant in self.areas:
|
||||||
result = selector(dominant)
|
result = selector(dominant)
|
||||||
|
@ -327,6 +390,7 @@ class GraphSplitByPattern:
|
||||||
|
|
||||||
def hfuse(self, selector):
|
def hfuse(self, selector):
|
||||||
"""Fuse horizontal areas with same input tensor"""
|
"""Fuse horizontal areas with same input tensor"""
|
||||||
|
|
||||||
def _do_fuse(areas):
|
def _do_fuse(areas):
|
||||||
for i in range(len(areas) - 1):
|
for i in range(len(areas) - 1):
|
||||||
sibling = []
|
sibling = []
|
||||||
|
@ -349,6 +413,7 @@ class GraphSplitByPattern:
|
||||||
self.areas.remove(area)
|
self.areas.remove(area)
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
changed = False
|
changed = False
|
||||||
while True:
|
while True:
|
||||||
for dom in self.areas:
|
for dom in self.areas:
|
||||||
|
@ -426,6 +491,7 @@ class GraphSplitByPattern:
|
||||||
|
|
||||||
def split_output_reshapes(self):
|
def split_output_reshapes(self):
|
||||||
"""Force split the output Reshapes into other new area"""
|
"""Force split the output Reshapes into other new area"""
|
||||||
|
|
||||||
def _remove_output_reshape(reshape_ops, other_ops):
|
def _remove_output_reshape(reshape_ops, other_ops):
|
||||||
def _run():
|
def _run():
|
||||||
for op in reshape_ops:
|
for op in reshape_ops:
|
||||||
|
@ -434,6 +500,7 @@ class GraphSplitByPattern:
|
||||||
other_ops.append(op)
|
other_ops.append(op)
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
while _run():
|
while _run():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -511,6 +578,7 @@ class GraphSplitByPattern:
|
||||||
|
|
||||||
def find_cheap_regions(self, dom):
|
def find_cheap_regions(self, dom):
|
||||||
"""extract all the cheap regions in dom area, toposort each region before return"""
|
"""extract all the cheap regions in dom area, toposort each region before return"""
|
||||||
|
|
||||||
def _grow_region(region_ops, op, weight, inputs):
|
def _grow_region(region_ops, op, weight, inputs):
|
||||||
"""include op to region_ops if region grow"""
|
"""include op to region_ops if region grow"""
|
||||||
# region successfully ends at inputs
|
# region successfully ends at inputs
|
||||||
|
@ -559,6 +627,7 @@ class GraphSplitByPattern:
|
||||||
|
|
||||||
def select_user_area(self, tail_tensor):
|
def select_user_area(self, tail_tensor):
|
||||||
"""select the user area has only one edge to dom area"""
|
"""select the user area has only one edge to dom area"""
|
||||||
|
|
||||||
def _get_edge_num(dom_area, user_area):
|
def _get_edge_num(dom_area, user_area):
|
||||||
"""get edge num between two areas"""
|
"""get edge num between two areas"""
|
||||||
dom_graph = self.to_subgraph(dom_area)
|
dom_graph = self.to_subgraph(dom_area)
|
||||||
|
@ -583,8 +652,10 @@ class GraphSplitByPattern:
|
||||||
|
|
||||||
def recompute_fuse(self):
|
def recompute_fuse(self):
|
||||||
"""find recompute regions and copy them out to new Areas"""
|
"""find recompute regions and copy them out to new Areas"""
|
||||||
|
|
||||||
def do_recompute_fuse():
|
def do_recompute_fuse():
|
||||||
"""split the unfusing pattern by add recompute area"""
|
"""split the unfusing pattern by add recompute area"""
|
||||||
|
|
||||||
def recompute_cheap_region(dom):
|
def recompute_cheap_region(dom):
|
||||||
for cheap_region in cheap_regions:
|
for cheap_region in cheap_regions:
|
||||||
user_areas = self.select_user_area(cheap_region[-1].output)
|
user_areas = self.select_user_area(cheap_region[-1].output)
|
||||||
|
@ -597,6 +668,7 @@ class GraphSplitByPattern:
|
||||||
if self.recom_res:
|
if self.recom_res:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
recompute_suc = False
|
recompute_suc = False
|
||||||
orig_areas = []
|
orig_areas = []
|
||||||
orig_areas.extend(self.areas)
|
orig_areas.extend(self.areas)
|
||||||
|
@ -628,6 +700,7 @@ class GraphSplitGpu(GraphSplitByPattern):
|
||||||
|
|
||||||
def pattern_fuse(self, fuse_func=None):
|
def pattern_fuse(self, fuse_func=None):
|
||||||
"""fuse Areas by pattern"""
|
"""fuse Areas by pattern"""
|
||||||
|
|
||||||
def _reshape(dom):
|
def _reshape(dom):
|
||||||
if dom.pattern != PrimLib.RESHAPE:
|
if dom.pattern != PrimLib.RESHAPE:
|
||||||
return None
|
return None
|
||||||
|
@ -638,8 +711,8 @@ class GraphSplitGpu(GraphSplitByPattern):
|
||||||
min_area = a
|
min_area = a
|
||||||
for a, _ in dom.in_relations.items():
|
for a, _ in dom.in_relations.items():
|
||||||
if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \
|
if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \
|
||||||
len(dom.ops[0].inputs[0].to_ops) == 1 and not a.is_output and \
|
len(dom.ops[0].inputs[0].to_ops) == 1 and not a.is_output and \
|
||||||
(min_area is None or a.pattern < min_area.pattern):
|
(min_area is None or a.pattern < min_area.pattern):
|
||||||
min_area, forward_fuse = a, True
|
min_area, forward_fuse = a, True
|
||||||
return ([min_area], forward_fuse) if min_area else None
|
return ([min_area], forward_fuse) if min_area else None
|
||||||
|
|
||||||
|
@ -718,12 +791,6 @@ class GraphSplitGpu(GraphSplitByPattern):
|
||||||
fused.append(a)
|
fused.append(a)
|
||||||
return fused, True
|
return fused, True
|
||||||
|
|
||||||
def _tensor_size(tensor):
|
|
||||||
size = 1
|
|
||||||
for i in tensor.shape:
|
|
||||||
size *= i
|
|
||||||
return size
|
|
||||||
|
|
||||||
def _is_atomic_add_available(dom):
|
def _is_atomic_add_available(dom):
|
||||||
if any(["Reduce" in x.prim for x in dom.ops[1:]]):
|
if any(["Reduce" in x.prim for x in dom.ops[1:]]):
|
||||||
return False
|
return False
|
||||||
|
@ -739,23 +806,16 @@ class GraphSplitGpu(GraphSplitByPattern):
|
||||||
return reduce_size >= 1024
|
return reduce_size >= 1024
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _reduce_nums(ops):
|
|
||||||
count = 0
|
|
||||||
for op in ops:
|
|
||||||
if op.prim.startswith('Reduce'):
|
|
||||||
count += 1
|
|
||||||
return count
|
|
||||||
|
|
||||||
def _reduce_output(dom):
|
def _reduce_output(dom):
|
||||||
if dom.pattern != PrimLib.REDUCE:
|
if dom.pattern != PrimLib.REDUCE:
|
||||||
return None
|
return None
|
||||||
if _reduce_nums(dom.ops) > 1:
|
if reduce_nums(dom.ops) > 1:
|
||||||
return None
|
return None
|
||||||
if _is_atomic_add_available(dom):
|
if _is_atomic_add_available(dom):
|
||||||
return None
|
return None
|
||||||
is_all_reduce = _tensor_size(dom.ops[0].output) == 1
|
is_all_reduce = tensor_size(dom.ops[0].output) == 1
|
||||||
# excluded large size all reduce
|
# excluded large size all reduce
|
||||||
if is_all_reduce and _tensor_size(dom.ops[0].inputs[0]) > 1024 * 12:
|
if is_all_reduce and tensor_size(dom.ops[0].inputs[0]) > 1024 * 12:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
fused = []
|
fused = []
|
||||||
|
@ -765,52 +825,17 @@ class GraphSplitGpu(GraphSplitByPattern):
|
||||||
fused.append(a)
|
fused.append(a)
|
||||||
return fused, False
|
return fused, False
|
||||||
|
|
||||||
def _stitch_axis(shape):
|
|
||||||
stitch_axis = []
|
|
||||||
size = 1
|
|
||||||
for i in shape:
|
|
||||||
size = size * i
|
|
||||||
stitch_axis.append(i)
|
|
||||||
if size >= 1024 * 8:
|
|
||||||
return stitch_axis
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _same_stitch_axis(a, b):
|
|
||||||
x = []
|
|
||||||
x.extend(a)
|
|
||||||
x.extend(b)
|
|
||||||
stitch_axis = _stitch_axis(x[0].shape)
|
|
||||||
for item in x:
|
|
||||||
i_stitch_axis = _stitch_axis(item.shape)
|
|
||||||
if i_stitch_axis is None or i_stitch_axis != stitch_axis:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _may_stitch(dom, a, r):
|
|
||||||
if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a):
|
|
||||||
if _reduce_nums(a.ops) >= 2:
|
|
||||||
return False
|
|
||||||
dom_outs = [op.output for op in dom.ops]
|
|
||||||
a_ins = [op_input for op in a.ops for op_input in op.inputs]
|
|
||||||
a_outs = [op.output for op in a.ops]
|
|
||||||
a_final_outs = [tensor for tensor in a_outs if tensor not in a_ins]
|
|
||||||
stitch_tensors = [tensor for tensor in dom_outs if tensor in a_ins]
|
|
||||||
if not _same_stitch_axis(stitch_tensors, a_final_outs):
|
|
||||||
return False
|
|
||||||
return any([_tensor_size(tensor) >= 1024 * 1024 for tensor in stitch_tensors])
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _reduce_stitch(dom):
|
def _reduce_stitch(dom):
|
||||||
if dom.pattern != PrimLib.REDUCE:
|
if dom.pattern != PrimLib.REDUCE:
|
||||||
return None
|
return None
|
||||||
if _tensor_size(dom.ops[0].output) == 1:
|
if tensor_size(dom.ops[0].output) == 1:
|
||||||
return None
|
return None
|
||||||
if _tensor_size(dom.ops[0].inputs[0]) < 1024 * 12:
|
if tensor_size(dom.ops[0].inputs[0]) < 1024 * 12:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
fused = []
|
fused = []
|
||||||
for a, r in dom.out_relations.items():
|
for a, r in dom.out_relations.items():
|
||||||
if not _may_stitch(dom, a, r):
|
if not may_stitch(dom, a, r, 1024 * 8, 1024 * 1024):
|
||||||
continue
|
continue
|
||||||
if a.pattern == PrimLib.REDUCE:
|
if a.pattern == PrimLib.REDUCE:
|
||||||
if a.ops[0].attrs['reduce_axis'] == dom.ops[0].attrs['reduce_axis']:
|
if a.ops[0].attrs['reduce_axis'] == dom.ops[0].attrs['reduce_axis']:
|
||||||
|
@ -881,7 +906,7 @@ class GraphSplitGpu(GraphSplitByPattern):
|
||||||
appected_areas = {"TensorScatterAdd", "UnsortedSegmentSum"}
|
appected_areas = {"TensorScatterAdd", "UnsortedSegmentSum"}
|
||||||
for a, _ in dom.out_relations.items():
|
for a, _ in dom.out_relations.items():
|
||||||
if _shape_consistent(gather_prims, appected_areas, dom, a) and \
|
if _shape_consistent(gather_prims, appected_areas, dom, a) and \
|
||||||
_count_target_prim(a.ops + dom.ops, appected_areas) < 2 and dom.check_acyclic(a):
|
_count_target_prim(a.ops + dom.ops, appected_areas) < 2 and dom.check_acyclic(a):
|
||||||
return [a], False
|
return [a], False
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -916,8 +941,8 @@ class GraphSplitGpu(GraphSplitByPattern):
|
||||||
for a in sibling:
|
for a in sibling:
|
||||||
op = a.ops[0]
|
op = a.ops[0]
|
||||||
if a.pattern == PrimLib.REDUCE and not a.stitch_info.stitch_ops and \
|
if a.pattern == PrimLib.REDUCE and not a.stitch_info.stitch_ops and \
|
||||||
PrimLib.is_reduce(op) and dom_op.inputs[0].shape == op.inputs[0].shape and \
|
PrimLib.is_reduce(op) and dom_op.inputs[0].shape == op.inputs[0].shape and \
|
||||||
dom_op.attrs.get("reduce_axis") == op.attrs.get("reduce_axis"):
|
dom_op.attrs.get("reduce_axis") == op.attrs.get("reduce_axis"):
|
||||||
fused.append(a)
|
fused.append(a)
|
||||||
return fused
|
return fused
|
||||||
|
|
||||||
|
@ -927,7 +952,7 @@ class GraphSplitGpu(GraphSplitByPattern):
|
||||||
fused = []
|
fused = []
|
||||||
for a in sibling:
|
for a in sibling:
|
||||||
if a.ops[0].prim == dom.ops[0].prim and dom.ops[0].output.shape == a.ops[0].output.shape and \
|
if a.ops[0].prim == dom.ops[0].prim and dom.ops[0].output.shape == a.ops[0].output.shape and \
|
||||||
dom.ops[0].inputs[0].shape == a.ops[0].inputs[0].shape:
|
dom.ops[0].inputs[0].shape == a.ops[0].inputs[0].shape:
|
||||||
fused.append(a)
|
fused.append(a)
|
||||||
return fused
|
return fused
|
||||||
|
|
||||||
|
@ -945,7 +970,7 @@ class GraphSplitGpu(GraphSplitByPattern):
|
||||||
changed = self.fuse(_broadcast_opaque) or changed
|
changed = self.fuse(_broadcast_opaque) or changed
|
||||||
changed = self.fuse(_gather_output) or changed
|
changed = self.fuse(_gather_output) or changed
|
||||||
changed = self.fuse(_reduce_output) or changed
|
changed = self.fuse(_reduce_output) or changed
|
||||||
if enable_stitch_fusion:
|
if self.enable_stitch_fusion:
|
||||||
changed = self.fuse(_reduce_stitch) or changed
|
changed = self.fuse(_reduce_stitch) or changed
|
||||||
self.fuse(_transpose)
|
self.fuse(_transpose)
|
||||||
self.hfuse(_h_broadcast)
|
self.hfuse(_h_broadcast)
|
||||||
|
@ -957,12 +982,11 @@ class GraphSplitGpu(GraphSplitByPattern):
|
||||||
fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \
|
fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \
|
||||||
fuse_func(_broadcast_width):
|
fuse_func(_broadcast_width):
|
||||||
return
|
return
|
||||||
if fuse_func(_reduce_output) or (enable_stitch_fusion and fuse_func(_reduce_stitch)):
|
if fuse_func(_reduce_output) or (self.enable_stitch_fusion and fuse_func(_reduce_stitch)):
|
||||||
return
|
return
|
||||||
fuse_func(_transpose)
|
fuse_func(_transpose)
|
||||||
return
|
return
|
||||||
|
|
||||||
enable_stitch_fusion = self.flags.get("enable_stitch_fusion", False)
|
|
||||||
if fuse_func is None:
|
if fuse_func is None:
|
||||||
_fuse_loop()
|
_fuse_loop()
|
||||||
else:
|
else:
|
||||||
|
@ -976,6 +1000,7 @@ class GraphSplitAscend(GraphSplitByPattern):
|
||||||
|
|
||||||
def get_default_mode(self, op):
|
def get_default_mode(self, op):
|
||||||
"""Get efault mode for Ascend"""
|
"""Get efault mode for Ascend"""
|
||||||
|
|
||||||
def _dtype_same(tensors):
|
def _dtype_same(tensors):
|
||||||
dtype = tensors[0].dtype
|
dtype = tensors[0].dtype
|
||||||
for tensor_ in tensors:
|
for tensor_ in tensors:
|
||||||
|
@ -992,15 +1017,10 @@ class GraphSplitAscend(GraphSplitByPattern):
|
||||||
|
|
||||||
def pattern_fuse(self, fuse_func=None):
|
def pattern_fuse(self, fuse_func=None):
|
||||||
"""fuse Areas by pattern"""
|
"""fuse Areas by pattern"""
|
||||||
def _tensor_size(tensor):
|
|
||||||
size = 1
|
|
||||||
for i in tensor.shape:
|
|
||||||
size *= i
|
|
||||||
return size
|
|
||||||
|
|
||||||
def _likely_multicore(dom):
|
def _likely_multicore(dom):
|
||||||
op = dom.dom_op()
|
op = dom.dom_op()
|
||||||
iter_size = _tensor_size(op.output if not PrimLib.is_reduce(op) else op.inputs[0])
|
iter_size = tensor_size(op.output if not PrimLib.is_reduce(op) else op.inputs[0])
|
||||||
return iter_size > 1024
|
return iter_size > 1024
|
||||||
|
|
||||||
def _reshape(dom):
|
def _reshape(dom):
|
||||||
|
@ -1013,8 +1033,8 @@ class GraphSplitAscend(GraphSplitByPattern):
|
||||||
min_area = a
|
min_area = a
|
||||||
for a, _ in dom.in_relations.items():
|
for a, _ in dom.in_relations.items():
|
||||||
if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \
|
if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \
|
||||||
len(dom.ops[0].inputs[0].to_ops) == 1 and not a.is_output and \
|
len(dom.ops[0].inputs[0].to_ops) == 1 and not a.is_output and \
|
||||||
(min_area is None or a.pattern < min_area.pattern):
|
(min_area is None or a.pattern < min_area.pattern):
|
||||||
min_area, forward_fuse = a, True
|
min_area, forward_fuse = a, True
|
||||||
return ([min_area], forward_fuse) if min_area else None
|
return ([min_area], forward_fuse) if min_area else None
|
||||||
|
|
||||||
|
@ -1111,6 +1131,27 @@ class GraphSplitAscend(GraphSplitByPattern):
|
||||||
fused.append(a)
|
fused.append(a)
|
||||||
return fused, False
|
return fused, False
|
||||||
|
|
||||||
|
def _reduce_stitch(dom):
|
||||||
|
if dom.pattern != PrimLib.REDUCE:
|
||||||
|
return None
|
||||||
|
if tensor_size(dom.ops[0].output) == 1:
|
||||||
|
return None
|
||||||
|
if tensor_size(dom.ops[0].inputs[0]) < 32 * 16 * 16:
|
||||||
|
return None
|
||||||
|
|
||||||
|
fused = []
|
||||||
|
for a, r in dom.out_relations.items():
|
||||||
|
if not may_stitch(dom, a, r, 32, 32 * 16 * 16):
|
||||||
|
continue
|
||||||
|
if a.pattern == PrimLib.REDUCE:
|
||||||
|
if a.ops[0].attrs['reduce_axis'] == dom.ops[0].attrs['reduce_axis']:
|
||||||
|
dom.stitch_info.stitch_ops.add(dom.ops[0].output.name)
|
||||||
|
fused.append(a)
|
||||||
|
elif a.pattern <= PrimLib.BROADCAST:
|
||||||
|
dom.stitch_info.stitch_ops.add(dom.ops[0].output.name)
|
||||||
|
fused.append(a)
|
||||||
|
return fused, False
|
||||||
|
|
||||||
def _transdata_pattern_support(dom, a):
|
def _transdata_pattern_support(dom, a):
|
||||||
transdata_op = dom.dom_op()
|
transdata_op = dom.dom_op()
|
||||||
|
|
||||||
|
@ -1127,6 +1168,7 @@ class GraphSplitAscend(GraphSplitByPattern):
|
||||||
if dim % cube_size != 0:
|
if dim % cube_size != 0:
|
||||||
res = True
|
res = True
|
||||||
return res
|
return res
|
||||||
|
|
||||||
has_pad = _has_pad()
|
has_pad = _has_pad()
|
||||||
if has_pad:
|
if has_pad:
|
||||||
return False
|
return False
|
||||||
|
@ -1145,7 +1187,7 @@ class GraphSplitAscend(GraphSplitByPattern):
|
||||||
if src_format == DF.FRAC_NZ and dst_format in (DF.DEFAULT, DF.NCHW):
|
if src_format == DF.FRAC_NZ and dst_format in (DF.DEFAULT, DF.NCHW):
|
||||||
return True
|
return True
|
||||||
# For the Default/NCHW to FRAC_NZ, currently only the Cast+Transdata is supported
|
# For the Default/NCHW to FRAC_NZ, currently only the Cast+Transdata is supported
|
||||||
if src_format in (DF.DEFAULT, DF.NCHW) and dst_format == DF.FRAC_NZ\
|
if src_format in (DF.DEFAULT, DF.NCHW) and dst_format == DF.FRAC_NZ \
|
||||||
and len(a.ops) == 1 and a.dom_op().prim == "Cast" and not a.is_output:
|
and len(a.ops) == 1 and a.dom_op().prim == "Cast" and not a.is_output:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
@ -1171,6 +1213,8 @@ class GraphSplitAscend(GraphSplitByPattern):
|
||||||
changed = self.fuse(_broadcast_width) or changed
|
changed = self.fuse(_broadcast_width) or changed
|
||||||
changed = self.fuse(_matmul_depth) or changed
|
changed = self.fuse(_matmul_depth) or changed
|
||||||
changed = self.fuse(_reduce_output) or changed
|
changed = self.fuse(_reduce_output) or changed
|
||||||
|
if self.enable_stitch_fusion:
|
||||||
|
changed = self.fuse(_reduce_stitch) or changed
|
||||||
self.fuse(_transdata)
|
self.fuse(_transdata)
|
||||||
|
|
||||||
def _fuse_once(fuse_func):
|
def _fuse_once(fuse_func):
|
||||||
|
|
|
@ -101,6 +101,7 @@ class OpInfer:
|
||||||
|
|
||||||
class _Elemwise(OpInfer):
|
class _Elemwise(OpInfer):
|
||||||
"""Common infer for elementwise operators"""
|
"""Common infer for elementwise operators"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def broadcast_shape(shapes):
|
def broadcast_shape(shapes):
|
||||||
"""deduce broadcast shape using same rules as numpy"""
|
"""deduce broadcast shape using same rules as numpy"""
|
||||||
|
@ -120,25 +121,24 @@ class _Elemwise(OpInfer):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def defaultformat_to_nz(default_shape):
|
def defaultformat_to_nz(default_shape):
|
||||||
"""default format shape to fractal_Nz format shape"""
|
"""default format shape to fractal_Nz format shape"""
|
||||||
if len(default_shape) not in (1, 2):
|
more_two_d_shape, two_d_shape = default_shape[:-2], default_shape[-2:]
|
||||||
raise GKException("shape is too long!")
|
|
||||||
# (32) or (1, 32) -> (2, 1, 1, 16)
|
# (32) or (1, 32) -> (2, 1, 1, 16)
|
||||||
if len(default_shape) == 1 or (len(default_shape) == 2 and default_shape[0] == 1):
|
if len(two_d_shape) == 1 or (len(two_d_shape) == 2 and two_d_shape[0] == 1):
|
||||||
shape = [default_shape[-1] // 16, 1, 1, 16]
|
shape = [two_d_shape[-1] // 16, 1, 1, 16]
|
||||||
if default_shape[-1] % 16 != 0:
|
if two_d_shape[-1] % 16 != 0:
|
||||||
raise GKException("should be multiplies of 16")
|
raise GKException("should be multiplies of 16")
|
||||||
return shape
|
return shape
|
||||||
# (32, 1) -> (1, 2, 16, 1)
|
# (32, 1) -> (1, 2, 16, 1)
|
||||||
if len(default_shape) == 2 and default_shape[1] == 1:
|
if len(two_d_shape) == 2 and two_d_shape[1] == 1:
|
||||||
shape = [1, default_shape[0] // 16, 16, 1]
|
shape = [1, two_d_shape[0] // 16, 16, 1]
|
||||||
if default_shape[0] % 16 != 0:
|
if two_d_shape[0] % 16 != 0:
|
||||||
raise GKException("should be multiples of 16")
|
raise GKException("should be multiples of 16")
|
||||||
return shape
|
return shape
|
||||||
# (32, 48) -> (3, 2, 16, 16)
|
# (32, 48) -> (3, 2, 16, 16)
|
||||||
shape = [default_shape[1] // 16, default_shape[0] // 16, 16, 16]
|
shape = [two_d_shape[1] // 16, two_d_shape[0] // 16, 16, 16]
|
||||||
if default_shape[0] % 16 != 0 or default_shape[1] % 16 != 0:
|
if two_d_shape[0] % 16 != 0 or two_d_shape[1] % 16 != 0:
|
||||||
raise GKException("should be multiples of 16")
|
raise GKException("should be multiples of 16")
|
||||||
return shape
|
return more_two_d_shape + shape
|
||||||
|
|
||||||
def _infer_shape(self):
|
def _infer_shape(self):
|
||||||
"""returns the output shape with broadcast"""
|
"""returns the output shape with broadcast"""
|
||||||
|
|
Loading…
Reference in New Issue