fix pylint warnings in graph_split.py

The GraphSplitCpu class is removed in this commit, and the environment variable "MS_DEV_GRAPH_KERNEL_PY_SPLIT_MODEL" is removed as well.
dayschan 2022-03-22 14:44:21 +08:00
parent 56183ab741
commit a3e8f0e3cf
3 changed files with 92 additions and 180 deletions
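Most of the Python changes below follow a few recurring pylint-driven patterns: fuse-strategy selectors return an empty list instead of None so call sites can test plain truthiness, direct dict indexing becomes dict.get(), list comprehensions inside any() become generator expressions, and helpers that never touch self become @staticmethod. A minimal sketch of the before/after shape of the first pattern, using made-up names (ToyDom, toy_selector) rather than code from graph_split.py:

class ToyDom:
    """Stand-in for an Area-like object with a pattern and input relations."""
    def __init__(self, pattern, in_relations):
        self.pattern = pattern
        self.in_relations = in_relations

def toy_selector(dom):
    """Old style returned None on a miss; the new style returns [] so callers can rely on truthiness."""
    if dom.pattern != "REDUCE":
        return []                          # was: return None
    fused = [a for a in dom.in_relations]  # areas to fuse with the dominant one
    return fused, True

result = toy_selector(ToyDom("ELEMWISE", {}))
if not result or not result[0]:            # was: if result is None or not result[0]:
    print("nothing to fuse for this dominant area")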


@@ -232,7 +232,7 @@ class CostModelSplitSchemer : public SplitSchemer {
std::shared_ptr<SplitSchemer> GraphKernelSplitterWithPy::GetSplitSchema(const std::string &processor) {
// default use c++ split model for CPU target.
if (processor != kCPUDevice || common::GetEnv("MS_DEV_GRAPH_KERNEL_PY_SPLIT_MODEL") == "on") {
if (processor != kCPUDevice) {
MS_LOG(DEBUG) << "use py split model";
return std::make_shared<CostModelSplitSchemer>();
} else {


@@ -50,10 +50,9 @@ class COMMON_EXPORT GraphKernelFlags {
GraphKernelFlags(const GraphKernelFlags &flags) = delete;
GraphKernelFlags(GraphKernelFlags &&flags) = delete;
void operator=(const GraphKernelFlags &flags) = delete;
GraphKernelFlags &operator=(const GraphKernelFlags &flags) = delete;
~GraphKernelFlags() = default;
public:
/**
* Dump info as human-readable text.
* A directory "graph_kernel_dump" will be created, and all information will be dumped in this directory.


@@ -52,7 +52,7 @@ def may_stitch(dom, a, r, stitch_axis_size, stitch_buffer_size):
stitchaxis.append(i)
if size >= stitch_axis_size:
return stitchaxis
return None
return []
x = []
x.extend(stitch_tensors)
@@ -60,21 +60,21 @@ def may_stitch(dom, a, r, stitch_axis_size, stitch_buffer_size):
stitch_axis_0 = _stitch_axis(x[0].shape, stitch_axis_size)
for item in x:
i_stitch_axis = _stitch_axis(item.shape, stitch_axis_size)
if i_stitch_axis is None or i_stitch_axis != stitch_axis_0:
if not i_stitch_axis or i_stitch_axis != stitch_axis_0:
return False
return True
if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a):
if reduce_nums(a.ops) >= 2:
return False
dom_outs = [op.output for op in dom.ops]
a_ins = [op_input for op in a.ops for op_input in op.inputs]
a_outs = [op.output for op in a.ops]
a_final_outs = [tensor for tensor in a_outs if tensor not in a_ins]
stitch_tensors = [tensor for tensor in dom_outs if tensor in a_ins]
dom_outs = set(op.output for op in dom.ops)
a_ins = set(op_input for op in a.ops for op_input in op.inputs)
a_outs = set(op.output for op in a.ops)
a_final_outs = list(tensor for tensor in a_outs if tensor not in a_ins)
stitch_tensors = list(tensor for tensor in dom_outs if tensor in a_ins)
if not _same_stitch_axis(stitch_tensors, a_final_outs, stitch_axis_size):
return False
return any([tensor_size(tensor) >= stitch_buffer_size for tensor in stitch_tensors])
return any((tensor_size(tensor) >= stitch_buffer_size for tensor in stitch_tensors))
return False
@@ -85,7 +85,7 @@ class CommonPattern:
def reshape(dom):
"""fuse strategy for reshape dom"""
if dom.pattern != PrimLib.RESHAPE:
return None
return []
min_area, forward_fuse = None, False
for a, _ in dom.out_relations.items():
if a.pattern <= PrimLib.BROADCAST and dom.check_acyclic(a) and \
@@ -96,24 +96,24 @@ class CommonPattern:
len(dom.ops[0].inputs[0].to_ops) == 1 and not a.is_output and \
(min_area is None or a.pattern < min_area.pattern):
min_area, forward_fuse = a, True
return ([min_area], forward_fuse) if min_area else None
return ([min_area], forward_fuse) if min_area else []
@staticmethod
def elemwise_depth(dom):
"""fuse strategy in depth for elemwise dom"""
if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.in_relations) != 1:
return None
return []
a, r = list(dom.in_relations.items())[0]
if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or r != PrimLib.ELEMWISE or \
a.dom_op().output.shape != dom.dom_op().output.shape:
return None
return []
return [a], True
@staticmethod
def elemwise_width(dom):
"""fuse strategy in width for elemwise dom"""
if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
return None
return []
fused = []
for a, r in dom.in_relations.items():
if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_acyclic(dom) and \
@@ -125,7 +125,7 @@ class CommonPattern:
def assign(dom):
"""fuse strategy for assign dom"""
if len(dom.ops) != 1 or dom.dom_op().prim != "Assign":
return None
return []
fused = []
for a, _ in dom.in_relations.items():
fused.append(a)
@@ -142,7 +142,7 @@ class GraphSplitByPattern:
self.map = []
self.alive = set(range(size))
for i in range(0, size):
self.map.append([False for j in range(0, size)])
self.map.append([False] * size)
self.map[i][i] = True
def reachable(self, x, y):
@@ -214,12 +214,13 @@ class GraphSplitByPattern:
self.reach_tab = reach_tab
def __str__(self):
return '<' + '-'.join([op.output.name for op in self.ops]) + '>'
return '<' + '-'.join((op.output.name for op in self.ops)) + '>'
def __repr__(self):
return str(self)
def get_relation(self, op, i):
@staticmethod
def get_relation(op, i):
"""Get op relation"""
relation = PrimLib.UNKNOWN
_, elem_relation = PrimLib.input_relation(op, i)
@@ -260,8 +261,8 @@ class GraphSplitByPattern:
def _update_pattern():
if area.pattern > self.pattern:
self.pattern = area.pattern
if area in self.in_relations and self.in_relations[area] > self.pattern:
self.pattern = self.in_relations[area]
if area in self.in_relations and self.in_relations.get(area) > self.pattern:
self.pattern = self.in_relations.get(area)
def _fuse_relation(self_relations, new_relations):
for a, r in new_relations.items():
@@ -333,8 +334,8 @@ class GraphSplitByPattern:
# copy ops
cp_ops = []
for op in area.recompute_ops:
inputs = [tensor_map[op.inputs[0]]] if op.inputs else []
cp_op = Operator(op.prim, inputs, tensor_map[op.output], op.attrs)
inputs = [tensor_map.get(op.inputs[0])] if op.inputs else []
cp_op = Operator(op.prim, inputs, tensor_map.get(op.output), op.attrs)
cp_op.all_inputs = cp_op.inputs
cp_ops.append(cp_op)
area.ori_op_map[cp_op] = op
@@ -342,9 +343,9 @@ class GraphSplitByPattern:
for op in self.ops:
if tail_tensor in op.inputs:
op.inputs.remove(tail_tensor)
op.inputs.append(tensor_map[tail_tensor])
op.inputs.append(tensor_map.get(tail_tensor))
tail_tensor.to_ops.remove(op)
tensor_map[tail_tensor].to_ops.append(op)
tensor_map.get(tail_tensor).to_ops.append(op)
# fill cp_ops in self.recompute_area
cp_dom_op = None
for cp, ori in area.ori_op_map.items():
@@ -352,7 +353,7 @@ class GraphSplitByPattern:
cp_dom_op = cp
area.ops.clear()
area.ops.append(cp_dom_op)
area.ops.extend([op for op in cp_ops if op != cp_dom_op])
area.ops.extend((op for op in cp_ops if op != cp_dom_op))
def __init__(self, graph, flags):
self.graph = graph
@@ -395,7 +396,8 @@ class GraphSplitByPattern:
"""Set default mode"""
area.mode = self.get_default_mode(area.ops[0])
def limit_area_size(self, dominant, fuse_areas, limit_size=200):
@staticmethod
def limit_area_size(dominant, fuse_areas, limit_size=200):
"""Remove some areas if the size is too large"""
area_sizes = map(lambda area: len(area.ops), fuse_areas)
dom_size = len(dominant.ops)
@@ -417,7 +419,7 @@ class GraphSplitByPattern:
def _fuse_area():
for dominant in self.areas:
result = selector(dominant)
if result is None or not result[0]:
if not result or not result[0]:
continue
fuse_areas, is_forward = result
fuse_areas = self.limit_area_size(dominant, fuse_areas)
@@ -460,7 +462,7 @@ class GraphSplitByPattern:
def _update_areas(areas, from_op):
for op in from_op.to_ops:
a = self.area_map[op]
a = self.area_map.get(op)
if a in self.areas and a not in areas:
areas.append(a)
@@ -488,7 +490,7 @@ class GraphSplitByPattern:
"""Fuse recompute area to its user"""
for dominant in [self.recom_area, self.recom_user]:
result = selector(dominant)
if result is not None and result[0]:
if result and result[0]:
fuse_areas, _ = result
fuse_areas = self.limit_area_size(dominant, fuse_areas)
if not fuse_areas:
@@ -507,7 +509,7 @@ class GraphSplitByPattern:
ids[op] = i
if hasattr(self, 'orig_op_map'):
for k, v in self.orig_op_map.items():
ids[k] = ids[v]
ids[k] = ids.get(v)
return ids
def to_subgraphs(self):
@@ -516,7 +518,7 @@ class GraphSplitByPattern:
subgraphs = []
graphmodes = []
for i, area in enumerate(self.areas):
area.ops.sort(key=lambda op: ids[op])
area.ops.sort(key=ids.get)
subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), area.ops, area.stitch_info, area.recompute_ops))
graphmodes.append("basic" if area.mode == self.Area.MODE_BASIC else "composite")
return subgraphs, graphmodes
@@ -542,7 +544,7 @@ class GraphSplitByPattern:
def _remove_output_reshape(reshape_ops, other_ops):
def _run():
for op in reshape_ops:
if any([to_op in other_ops for to_op in op.output.to_ops]):
if any((to_op in other_ops for to_op in op.output.to_ops)):
reshape_ops.remove(op)
other_ops.append(op)
return True
@@ -553,8 +555,8 @@ class GraphSplitByPattern:
new_areas = []
for area in self.areas:
reshape_ops = [op for op in area.ops if PrimLib.iter_type(op) == PrimLib.RESHAPE]
other_ops = [op for op in area.ops if op not in reshape_ops]
reshape_ops = list(op for op in area.ops if PrimLib.iter_type(op) == PrimLib.RESHAPE)
other_ops = list(op for op in area.ops if op not in reshape_ops)
if not other_ops or not reshape_ops:
continue
# remove the output reshape from "reshape_ops" and add it into "other_ops"
@@ -575,7 +577,7 @@ class GraphSplitByPattern:
"""set the recompute area and connect with other areas"""
self.recom_area.recompute_ops.extend(ops)
# recom_area: set dom_op and correct ops length
patterns = [PrimLib.iter_type(op) for op in ops]
patterns = list(PrimLib.iter_type(op) for op in ops)
self.recom_area.pattern = max(patterns)
for i, pat in enumerate(patterns):
if pat == self.recom_area.pattern:
@@ -589,7 +591,7 @@ class GraphSplitByPattern:
user_area.in_relations[self.recom_area] = self.dom_user_r
self.recom_area.out_relations[user_area] = self.dom_user_r
# connect recom_pre and recom_area
self.recom_pre = self.area_map[ops[0].inputs[0].op] if ops[0].inputs and ops[0].inputs[0].op else None
self.recom_pre = self.area_map.get(ops[0].inputs[0].op) if ops[0].inputs and ops[0].inputs[0].op else None
if self.recom_pre is not None:
self.recom_area.in_relations[self.recom_pre] = dom_area.in_relations[self.recom_pre]
self.recom_pre.out_relations[self.recom_area] = dom_area.in_relations[self.recom_pre]
@@ -618,7 +620,7 @@ class GraphSplitByPattern:
ids = self.index_op()
dom_ops = list()
dom_ops.extend(dom.ops)
dom_ops.sort(key=lambda op: ids[op])
dom_ops.sort(key=ids.get)
subgraph = []
subgraph = Graph('{}_area'.format(self.graph.name), dom_ops)
return subgraph
@@ -631,15 +633,15 @@ class GraphSplitByPattern:
# region successfully ends at inputs
if not op.inputs:
region_ops.append(op)
return False, None, weight, True
return False, op, weight, True
if op.inputs[0] in inputs and len(op.inputs) == 1 and \
PrimLib.iter_type(op) <= PrimLib.BROADCAST:
region_ops.append(op)
return False, None, weight, True
return False, op, weight, True
# region fails to grow
max_weight = 20
if weight > max_weight or len(op.inputs) > 1 or PrimLib.iter_type(op) > PrimLib.BROADCAST:
return False, None, weight, False
return False, op, weight, False
# region grows successfully
weight = weight + 1
region_ops.append(op)
@@ -681,16 +683,15 @@ class GraphSplitByPattern:
_, dom_outputs = dom_graph.deduce_parameters()
user_graph = self.to_subgraph(user_area)
user_inputs, _ = user_graph.deduce_parameters()
edge = [t for t in dom_outputs if t in user_inputs]
return len(edge)
return len(list(t for t in dom_outputs if t in user_inputs))
def _select_user_area(tail_tensor):
user_areas = []
for user_op in tail_tensor.to_ops:
user_area = self.area_map[user_op]
user_area = self.area_map.get(user_op)
if user_area.pattern == PrimLib.RESHAPE:
continue
edge_num = _get_edge_num(self.area_map[tail_tensor.op], user_area)
edge_num = _get_edge_num(self.area_map.get(tail_tensor.op), user_area)
if edge_num == 1 and not user_area in user_areas:
user_areas.append(user_area)
return user_areas
@@ -757,21 +758,21 @@ class GraphSplitGpu(GraphSplitByPattern):
def _broadcast_depth(dom):
if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1 or \
dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH:
return None
return []
a, r = list(dom.out_relations.items())[0]
if _broadcast_pat_exclude(dom, a, r) or len(a.in_relations) != 1:
return None
return []
return [a], False
def _broadcast_width(dom):
if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or \
dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH:
return None
return []
fused = []
for a, r in dom.out_relations.items():
if _broadcast_pat_exclude(dom, a, r) or not dom.check_acyclic(a) or \
(fused and fused[0].dom_op().output.shape != a.dom_op().output.shape):
return None
return []
fused.append(a)
return fused, False
@@ -782,19 +783,19 @@ class GraphSplitGpu(GraphSplitByPattern):
def _reduce_depth(dom):
if dom.pattern != PrimLib.REDUCE or len(dom.in_relations) != 1:
return None
return []
a, r = list(dom.in_relations.items())[0]
if dom.ops[0].inputs[0].dtype == "float16" and a.is_output and len(a.ops) >= 10 and \
_is_atomic_add_available(dom):
# to evade the precision problem.
return None
return []
if _reduce_pat_exclude(dom, a, r) or len(a.out_relations) != 1:
return None
return []
return [a], True
def _reduce_width(dom):
if dom.pattern != PrimLib.REDUCE:
return None
return []
fused = []
for a, r in dom.in_relations.items():
if dom.ops[0].inputs[0].dtype == "float16" and a.is_output and len(a.ops) >= 10 and \
@@ -806,7 +807,7 @@ class GraphSplitGpu(GraphSplitByPattern):
return fused, True
def _is_atomic_add_available(dom):
if any(["Reduce" in x.prim for x in dom.ops[1:]]):
if any(("Reduce" in x.prim for x in dom.ops[1:])):
return False
op = dom.ops[0]
if "reduce_axis" in op.attrs:
@@ -816,21 +817,21 @@ class GraphSplitGpu(GraphSplitByPattern):
else:
raise Exception("For '{}', can not find the attr 'reduce_axis' or 'axis'".format(op.prim))
if len(op.inputs[0].shape) - 1 in reduce_axis:
reduce_size = prod_reduce(lambda x, y: x * y, [op.inputs[0].shape[i] for i in reduce_axis])
reduce_size = prod_reduce(lambda x, y: x * y, (op.inputs[0].shape[i] for i in reduce_axis))
return reduce_size >= 1024
return True
def _reduce_output(dom):
if dom.pattern != PrimLib.REDUCE:
return None
return []
if reduce_nums(dom.ops) > 1:
return None
return []
if _is_atomic_add_available(dom):
return None
return []
is_all_reduce = tensor_size(dom.ops[0].output) == 1
# excluded large size all reduce
if is_all_reduce and tensor_size(dom.ops[0].inputs[0]) > 1024 * 12:
return None
return []
fused = []
for a, r in dom.out_relations.items():
@@ -841,11 +842,11 @@ class GraphSplitGpu(GraphSplitByPattern):
def _reduce_stitch(dom):
if dom.pattern != PrimLib.REDUCE:
return None
return []
if tensor_size(dom.ops[0].output) == 1:
return None
return []
if tensor_size(dom.ops[0].inputs[0]) < 1024 * 12:
return None
return []
fused = []
for a, r in dom.out_relations.items():
@@ -862,7 +863,7 @@ class GraphSplitGpu(GraphSplitByPattern):
def _transpose(dom):
if len(dom.ops) != 1 or dom.ops[0].prim != "Transpose":
return None
return []
fused = []
for a, _ in dom.in_relations.items():
if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and len(a.ops) <= self.TRANSPOSE_FUSE_DEPTH:
@@ -871,7 +872,7 @@ class GraphSplitGpu(GraphSplitByPattern):
def _strided_slice(dom):
if dom.dom_op().prim != "StridedSlice":
return None
return []
fused = []
for a, _ in dom.in_relations.items():
if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \
@@ -882,7 +883,7 @@ class GraphSplitGpu(GraphSplitByPattern):
def _gather_output(dom, reduce_fusion=False):
gather_prims = ("Gather", "GatherNd")
if not dom.dom_op().prim in gather_prims:
return None
return []
def _reduce_exclude(op, axis_list):
""" Whether this operator should be excluded.
@@ -956,7 +957,7 @@ class GraphSplitGpu(GraphSplitByPattern):
start_ops.append(op)
end_ops = []
for op in total_ops:
if op.prim in end_prims and not any([to_op in total_ops for to_op in op.output.to_ops]):
if op.prim in end_prims and not any((to_op in total_ops for to_op in op.output.to_ops)):
end_ops.append(op)
for start_op in start_ops:
@@ -977,21 +978,21 @@ class GraphSplitGpu(GraphSplitByPattern):
for a, _ in dom.out_relations.items():
if _shape_consistent(gather_prims, appected_areas, dom, a) and dom.check_acyclic(a):
return [a], False
return None
return []
def _broadcast_opaque(dom):
"""Fuse rule for TensorScatterAdd and UnsortedSegmentSum."""
def _same_input(op1, op2):
return bool(set(op1.inputs.copy()) & set(op2.inputs.copy()))
return bool(set(op1.inputs) & set(op2.inputs))
if len(dom.ops) != 1:
return None
return []
# Only fuse the first input for `TensorScatterAdd`` and the first and second input for `UnsortedSegmentSum`.
fuse_arg = {"TensorScatterAdd": slice(1, None), "UnsortedSegmentSum": slice(0, 2)}
arg_idx = fuse_arg.get(dom.dom_op().prim, -1)
if arg_idx == -1:
return None
return []
fuse_tensor = dom.dom_op().inputs[arg_idx]
for a, _ in dom.in_relations.items():
@@ -1000,21 +1001,21 @@ class GraphSplitGpu(GraphSplitByPattern):
return [a], True
# Rule 2: Fuse op(reshape/elementwise/broadcast) in specified position inputs.
if a.pattern <= PrimLib.BROADCAST and dom.check_acyclic(a) and \
any([op.output in fuse_tensor for op in a.ops]):
any((op.output in fuse_tensor for op in a.ops)):
return [a], True
return None
return []
def _h_broadcast(dom, a):
if dom.pattern > PrimLib.BROADCAST:
return None
return []
return a.pattern <= PrimLib.BROADCAST and dom.ops[0].output.shape == a.ops[0].output.shape
def _h_reduce(dom, a):
if dom.pattern != PrimLib.REDUCE or dom.stitch_info.stitch_ops:
return None
return []
dom_op = dom.ops[0]
if not PrimLib.is_reduce(dom_op) or _is_atomic_add_available(dom):
return None
return []
op = a.ops[0]
return a.pattern == PrimLib.REDUCE and not a.stitch_info.stitch_ops and \
PrimLib.is_reduce(op) and dom_op.inputs[0].shape == op.inputs[0].shape and \
@@ -1022,7 +1023,7 @@ class GraphSplitGpu(GraphSplitByPattern):
def _h_opaque(dom, a):
if dom.ops[0].prim not in {"StridedSlice"}:
return None
return []
return a.ops[0].prim == dom.ops[0].prim and dom.ops[0].output.shape == a.ops[0].output.shape and \
dom.ops[0].inputs[0].shape == a.ops[0].inputs[0].shape
@@ -1103,20 +1104,20 @@ class GraphSplitAscend(GraphSplitByPattern):
def _broadcast_depth(dom):
if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1:
return None
return []
a, r = list(dom.out_relations.items())[0]
if _broadcast_pat_exclude(dom, a, r) or len(a.in_relations) != 1:
return None
return []
return [a], False
def _broadcast_width(dom):
if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
return None
return []
fused = []
for a, r in dom.out_relations.items():
if _broadcast_pat_exclude(dom, a, r) or not dom.check_acyclic(a) or \
(fused and fused[0].dom_op().output.shape != a.dom_op().output.shape):
return None
return []
fused.append(a)
return fused, False
@@ -1130,15 +1131,15 @@ class GraphSplitAscend(GraphSplitByPattern):
def _reduce_depth(dom):
if dom.pattern != PrimLib.REDUCE or len(dom.in_relations) != 1:
return None
return []
a, r = list(dom.in_relations.items())[0]
if _reduce_pat_exclude(dom, a, r) or len(a.out_relations) != 1:
return None
return []
return [a], True
def _reduce_width(dom):
if dom.pattern != PrimLib.REDUCE:
return None
return []
fused = []
for a, r in dom.in_relations.items():
if not _reduce_pat_exclude(dom, a, r) and a.check_acyclic(dom):
@@ -1147,7 +1148,7 @@ class GraphSplitAscend(GraphSplitByPattern):
def _matmul_depth(dom):
if dom.dom_op().prim != "MatMul" and dom.dom_op().prim != "BatchMatMul":
return None
return []
fused = []
for a, _ in dom.out_relations.items():
if (((a.dom_op().prim == "AddN" or a.dom_op().prim == "Add" or a.dom_op().prim == "Cast")
@@ -1159,10 +1160,10 @@ class GraphSplitAscend(GraphSplitByPattern):
def _reduce_output(dom):
if dom.pattern != PrimLib.REDUCE:
return None
return []
op_attrs = dom.dom_op().attrs
if not op_attrs.get('reduce_output_fuse'):
return None
return []
fused = []
for a, r in dom.out_relations.items():
if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \
@@ -1172,11 +1173,11 @@ class GraphSplitAscend(GraphSplitByPattern):
def _reduce_stitch(dom):
if dom.pattern != PrimLib.REDUCE:
return None
return []
if tensor_size(dom.ops[0].output) == 1:
return None
return []
if tensor_size(dom.ops[0].inputs[0]) < 32 * 16 * 16:
return None
return []
fused = []
for a, r in dom.out_relations.items():
@@ -1233,7 +1234,7 @@ class GraphSplitAscend(GraphSplitByPattern):
def _transdata(dom):
if dom.dom_op().prim != "TransData":
return None
return []
fused = []
for a, _ in dom.in_relations.items():
if _transdata_pattern_support(dom, a) and a.check_acyclic(dom):
@@ -1270,99 +1271,11 @@ class GraphSplitAscend(GraphSplitByPattern):
_fuse_once(fuse_func)
class GraphSplitCpu(GraphSplitByPattern):
"""Graph splitter"""
BROADCAST_FUSE_DEPTH = 20
REDUCE_FUSE_DEPTH = 20
def get_default_mode(self, op):
"""Get default mode in CPU"""
del op
return self.Area.MODE_COMPOSITE
def pattern_fuse(self, fuse_func=None):
"""fuse Areas by pattern"""
def _broadcast_pat_exclude(dom, a, r):
if a.pattern == PrimLib.REDUCE:
return dom.pattern > PrimLib.ELEMWISE or r > PrimLib.ELEMWISE
return a.pattern > PrimLib.REDUCE or r > PrimLib.BROADCAST
def _broadcast_depth(dom):
if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1 or \
dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH:
return None
a, r = list(dom.out_relations.items())[0]
if _broadcast_pat_exclude(dom, a, r) or len(a.in_relations) != 1:
return None
return [a], False
def _broadcast_width(dom):
if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or \
dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH:
return None
fused = []
for a, r in dom.out_relations.items():
if _broadcast_pat_exclude(dom, a, r) or not dom.check_acyclic(a) or \
(fused and fused[0].dom_op().output.shape != a.dom_op().output.shape):
return None
fused.append(a)
return fused, False
def _reduce_pat_exclude(_, a, r):
if len(a.ops) > self.REDUCE_FUSE_DEPTH:
return True
return a.pattern > PrimLib.ELEMWISE or r > PrimLib.REDUCE or r == PrimLib.BROADCAST
def _reduce_width(dom):
if dom.pattern != PrimLib.REDUCE:
return None
fused = []
for a, r in dom.in_relations.items():
if not _reduce_pat_exclude(dom, a, r) and a.check_acyclic(dom):
fused.append(a)
return fused, True
def _reduce_depth(dom):
if dom.pattern != PrimLib.REDUCE or len(dom.in_relations) != 1:
return None
a, r = list(dom.in_relations.items())[0]
if _reduce_pat_exclude(dom, a, r) or len(a.out_relations) != 1:
return None
return [a], True
def _fuse_loop():
changed = True
while changed:
changed = False
changed = self.fuse(CommonPattern.reshape) or changed
changed = self.fuse(CommonPattern.assign) or changed
changed = self.fuse(CommonPattern.elemwise_depth) or changed
changed = self.fuse(CommonPattern.elemwise_width) or changed
changed = self.fuse(_reduce_depth) or changed
changed = self.fuse(_reduce_width) or changed
changed = self.fuse(_broadcast_depth) or changed
changed = self.fuse(_broadcast_width) or changed
def _fuse_once(fuse_func):
if fuse_func(CommonPattern.reshape) or fuse_func(CommonPattern.elemwise_depth) or \
fuse_func(CommonPattern.elemwise_width) or fuse_func(_reduce_depth) or \
fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or fuse_func(_broadcast_width):
return
if fuse_func is None:
_fuse_loop()
else:
_fuse_once(fuse_func)
def split(graph, target, flags):
"""Split graph"""
result = None
if target == "cuda":
result = GraphSplitGpu(graph, flags).split()
elif target == "aicore":
result = GraphSplitAscend(graph, flags).split()
else:
result = GraphSplitCpu(graph, flags).split()
result = GraphSplitAscend(graph, flags).split()
return result
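As a closing note, a tiny runnable sketch of the resulting dispatch behaviour; toy_split is a made-up stand-in that ignores graph and flags and only illustrates that every target other than "cuda" (including the former CPU case) now ends up on the Ascend splitter:

def toy_split(target):
    """Illustrative only: mirrors the post-commit target dispatch, nothing more."""
    if target == "cuda":
        return "GraphSplitGpu"
    # "aicore" and any remaining target (previously GraphSplitCpu) take the Ascend path
    return "GraphSplitAscend"

assert toy_split("cuda") == "GraphSplitGpu"
assert toy_split("anything_else") == "GraphSplitAscend"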