forked from mindspore-Ecosystem/mindspore
fix pylint warnings of graph_split.py
GraphSplitCpu is removed in this commit, and the environment variable "MS_DEV_GRAPH_KERNEL_PY_SPLIT_MODEL" is removed as well.
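A few pylint-driven rewrites recur throughout the diff below: fuse-strategy helpers now return an empty list instead of None (so every code path returns the same falsy type), temporary lists inside any()/join() become generator expressions, and dict indexing is replaced with dict.get(). A minimal sketch of these patterns follows; the pick_areas helper and its names are hypothetical, not taken from the patch:

# Minimal sketch of the recurring pylint fixes in this commit; pick_areas
# and area_map are illustrative names, not from graph_split.py.
def pick_areas(candidates, area_map):
    """Return fusable areas, or [] (never None) when nothing can fuse."""
    if not candidates:
        # inconsistent-return-statements: every path returns the same
        # (falsy) list type, so callers can simply write `if not result:`.
        return []
    if any(name.startswith("_") for name in candidates):
        # generator expression instead of a temporary list inside any()
        return []
    # dict.get() instead of subscripting, as done throughout the patch
    return [area_map.get(name) for name in candidates]

assert pick_areas(["relu", "add"], {"relu": 1, "add": 2}) == [1, 2]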
parent 56183ab741
commit a3e8f0e3cf
@@ -232,7 +232,7 @@ class CostModelSplitSchemer : public SplitSchemer {
 std::shared_ptr<SplitSchemer> GraphKernelSplitterWithPy::GetSplitSchema(const std::string &processor) {
   // default use c++ split model for CPU target.
-  if (processor != kCPUDevice || common::GetEnv("MS_DEV_GRAPH_KERNEL_PY_SPLIT_MODEL") == "on") {
+  if (processor != kCPUDevice) {
     MS_LOG(DEBUG) << "use py split model";
     return std::make_shared<CostModelSplitSchemer>();
   } else {
@@ -50,10 +50,9 @@ class COMMON_EXPORT GraphKernelFlags {
   GraphKernelFlags(const GraphKernelFlags &flags) = delete;
   GraphKernelFlags(GraphKernelFlags &&flags) = delete;
-  void operator=(const GraphKernelFlags &flags) = delete;
+  GraphKernelFlags &operator=(const GraphKernelFlags &flags) = delete;
   ~GraphKernelFlags() = default;
 
  public:
   /**
    * Dump info as human-readable text.
    * A directory "graph_kernel_dump" will be created, and all information will be dumped in this directory.
@@ -52,7 +52,7 @@ def may_stitch(dom, a, r, stitch_axis_size, stitch_buffer_size):
             stitchaxis.append(i)
             if size >= stitch_axis_size:
                 return stitchaxis
-        return None
+        return []
 
         x = []
         x.extend(stitch_tensors)
@@ -60,21 +60,21 @@ def may_stitch(dom, a, r, stitch_axis_size, stitch_buffer_size):
         stitch_axis_0 = _stitch_axis(x[0].shape, stitch_axis_size)
         for item in x:
             i_stitch_axis = _stitch_axis(item.shape, stitch_axis_size)
-            if i_stitch_axis is None or i_stitch_axis != stitch_axis_0:
+            if not i_stitch_axis or i_stitch_axis != stitch_axis_0:
                 return False
         return True
 
     if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a):
         if reduce_nums(a.ops) >= 2:
             return False
-        dom_outs = [op.output for op in dom.ops]
-        a_ins = [op_input for op in a.ops for op_input in op.inputs]
-        a_outs = [op.output for op in a.ops]
-        a_final_outs = [tensor for tensor in a_outs if tensor not in a_ins]
-        stitch_tensors = [tensor for tensor in dom_outs if tensor in a_ins]
+        dom_outs = set(op.output for op in dom.ops)
+        a_ins = set(op_input for op in a.ops for op_input in op.inputs)
+        a_outs = set(op.output for op in a.ops)
+        a_final_outs = list(tensor for tensor in a_outs if tensor not in a_ins)
+        stitch_tensors = list(tensor for tensor in dom_outs if tensor in a_ins)
         if not _same_stitch_axis(stitch_tensors, a_final_outs, stitch_axis_size):
             return False
-        return any([tensor_size(tensor) >= stitch_buffer_size for tensor in stitch_tensors])
+        return any((tensor_size(tensor) >= stitch_buffer_size for tensor in stitch_tensors))
     return False
@@ -85,7 +85,7 @@ class CommonPattern:
     def reshape(dom):
         """fuse strategy for reshape dom"""
         if dom.pattern != PrimLib.RESHAPE:
-            return None
+            return []
         min_area, forward_fuse = None, False
         for a, _ in dom.out_relations.items():
             if a.pattern <= PrimLib.BROADCAST and dom.check_acyclic(a) and \
@@ -96,24 +96,24 @@ class CommonPattern:
                     len(dom.ops[0].inputs[0].to_ops) == 1 and not a.is_output and \
                     (min_area is None or a.pattern < min_area.pattern):
                 min_area, forward_fuse = a, True
-        return ([min_area], forward_fuse) if min_area else None
+        return ([min_area], forward_fuse) if min_area else []
 
     @staticmethod
     def elemwise_depth(dom):
         """fuse strategy in depth for elemwise dom"""
         if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.in_relations) != 1:
-            return None
+            return []
         a, r = list(dom.in_relations.items())[0]
         if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or r != PrimLib.ELEMWISE or \
                 a.dom_op().output.shape != dom.dom_op().output.shape:
-            return None
+            return []
         return [a], True
 
     @staticmethod
     def elemwise_width(dom):
         """fuse strategy in width for elemwise dom"""
         if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
-            return None
+            return []
         fused = []
         for a, r in dom.in_relations.items():
             if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_acyclic(dom) and \
@@ -125,7 +125,7 @@ class CommonPattern:
     def assign(dom):
         """fuse strategy for assign dom"""
         if len(dom.ops) != 1 or dom.dom_op().prim != "Assign":
-            return None
+            return []
         fused = []
         for a, _ in dom.in_relations.items():
             fused.append(a)
@@ -142,7 +142,7 @@ class GraphSplitByPattern:
             self.map = []
             self.alive = set(range(size))
             for i in range(0, size):
-                self.map.append([False for j in range(0, size)])
+                self.map.append([False] * size)
                 self.map[i][i] = True
 
         def reachable(self, x, y):
@@ -214,12 +214,13 @@ class GraphSplitByPattern:
             self.reach_tab = reach_tab
 
         def __str__(self):
-            return '<' + '-'.join([op.output.name for op in self.ops]) + '>'
+            return '<' + '-'.join((op.output.name for op in self.ops)) + '>'
 
         def __repr__(self):
             return str(self)
 
-        def get_relation(self, op, i):
+        @staticmethod
+        def get_relation(op, i):
             """Get op relation"""
             relation = PrimLib.UNKNOWN
             _, elem_relation = PrimLib.input_relation(op, i)
@@ -260,8 +261,8 @@ class GraphSplitByPattern:
             def _update_pattern():
                 if area.pattern > self.pattern:
                     self.pattern = area.pattern
-                if area in self.in_relations and self.in_relations[area] > self.pattern:
-                    self.pattern = self.in_relations[area]
+                if area in self.in_relations and self.in_relations.get(area) > self.pattern:
+                    self.pattern = self.in_relations.get(area)
 
             def _fuse_relation(self_relations, new_relations):
                 for a, r in new_relations.items():
@@ -333,8 +334,8 @@ class GraphSplitByPattern:
             # copy ops
             cp_ops = []
             for op in area.recompute_ops:
-                inputs = [tensor_map[op.inputs[0]]] if op.inputs else []
-                cp_op = Operator(op.prim, inputs, tensor_map[op.output], op.attrs)
+                inputs = [tensor_map.get(op.inputs[0])] if op.inputs else []
+                cp_op = Operator(op.prim, inputs, tensor_map.get(op.output), op.attrs)
                 cp_op.all_inputs = cp_op.inputs
                 cp_ops.append(cp_op)
                 area.ori_op_map[cp_op] = op
@@ -342,9 +343,9 @@ class GraphSplitByPattern:
             for op in self.ops:
                 if tail_tensor in op.inputs:
                     op.inputs.remove(tail_tensor)
-                    op.inputs.append(tensor_map[tail_tensor])
+                    op.inputs.append(tensor_map.get(tail_tensor))
                     tail_tensor.to_ops.remove(op)
-                    tensor_map[tail_tensor].to_ops.append(op)
+                    tensor_map.get(tail_tensor).to_ops.append(op)
             # fill cp_ops in self.recompute_area
             cp_dom_op = None
             for cp, ori in area.ori_op_map.items():
@@ -352,7 +353,7 @@ class GraphSplitByPattern:
                     cp_dom_op = cp
             area.ops.clear()
             area.ops.append(cp_dom_op)
-            area.ops.extend([op for op in cp_ops if op != cp_dom_op])
+            area.ops.extend((op for op in cp_ops if op != cp_dom_op))
 
     def __init__(self, graph, flags):
         self.graph = graph
@@ -395,7 +396,8 @@ class GraphSplitByPattern:
         """Set default mode"""
         area.mode = self.get_default_mode(area.ops[0])
 
-    def limit_area_size(self, dominant, fuse_areas, limit_size=200):
+    @staticmethod
+    def limit_area_size(dominant, fuse_areas, limit_size=200):
         """Remove some areas if the size is too large"""
         area_sizes = map(lambda area: len(area.ops), fuse_areas)
         dom_size = len(dominant.ops)
@@ -417,7 +419,7 @@ class GraphSplitByPattern:
         def _fuse_area():
             for dominant in self.areas:
                 result = selector(dominant)
-                if result is None or not result[0]:
+                if not result or not result[0]:
                     continue
                 fuse_areas, is_forward = result
                 fuse_areas = self.limit_area_size(dominant, fuse_areas)
@@ -460,7 +462,7 @@ class GraphSplitByPattern:
 
         def _update_areas(areas, from_op):
             for op in from_op.to_ops:
-                a = self.area_map[op]
+                a = self.area_map.get(op)
                 if a in self.areas and a not in areas:
                     areas.append(a)
@@ -488,7 +490,7 @@ class GraphSplitByPattern:
         """Fuse recompute area to its user"""
         for dominant in [self.recom_area, self.recom_user]:
             result = selector(dominant)
-            if result is not None and result[0]:
+            if result and result[0]:
                 fuse_areas, _ = result
                 fuse_areas = self.limit_area_size(dominant, fuse_areas)
                 if not fuse_areas:
@@ -507,7 +509,7 @@ class GraphSplitByPattern:
             ids[op] = i
         if hasattr(self, 'orig_op_map'):
             for k, v in self.orig_op_map.items():
-                ids[k] = ids[v]
+                ids[k] = ids.get(v)
         return ids
 
     def to_subgraphs(self):
@@ -516,7 +518,7 @@ class GraphSplitByPattern:
         subgraphs = []
         graphmodes = []
         for i, area in enumerate(self.areas):
-            area.ops.sort(key=lambda op: ids[op])
+            area.ops.sort(key=ids.get)
             subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), area.ops, area.stitch_info, area.recompute_ops))
             graphmodes.append("basic" if area.mode == self.Area.MODE_BASIC else "composite")
         return subgraphs, graphmodes
@@ -542,7 +544,7 @@ class GraphSplitByPattern:
         def _remove_output_reshape(reshape_ops, other_ops):
             def _run():
                 for op in reshape_ops:
-                    if any([to_op in other_ops for to_op in op.output.to_ops]):
+                    if any((to_op in other_ops for to_op in op.output.to_ops)):
                         reshape_ops.remove(op)
                         other_ops.append(op)
                         return True
@@ -553,8 +555,8 @@ class GraphSplitByPattern:
 
         new_areas = []
         for area in self.areas:
-            reshape_ops = [op for op in area.ops if PrimLib.iter_type(op) == PrimLib.RESHAPE]
-            other_ops = [op for op in area.ops if op not in reshape_ops]
+            reshape_ops = list(op for op in area.ops if PrimLib.iter_type(op) == PrimLib.RESHAPE)
+            other_ops = list(op for op in area.ops if op not in reshape_ops)
             if not other_ops or not reshape_ops:
                 continue
             # remove the output reshape from "reshape_ops" and add it into "other_ops"
@@ -575,7 +577,7 @@ class GraphSplitByPattern:
         """set the recompute area and connect with other areas"""
         self.recom_area.recompute_ops.extend(ops)
         # recom_area: set dom_op and correct ops length
-        patterns = [PrimLib.iter_type(op) for op in ops]
+        patterns = list(PrimLib.iter_type(op) for op in ops)
         self.recom_area.pattern = max(patterns)
         for i, pat in enumerate(patterns):
             if pat == self.recom_area.pattern:
@@ -589,7 +591,7 @@ class GraphSplitByPattern:
         user_area.in_relations[self.recom_area] = self.dom_user_r
         self.recom_area.out_relations[user_area] = self.dom_user_r
         # connect recom_pre and recom_area
-        self.recom_pre = self.area_map[ops[0].inputs[0].op] if ops[0].inputs and ops[0].inputs[0].op else None
+        self.recom_pre = self.area_map.get(ops[0].inputs[0].op) if ops[0].inputs and ops[0].inputs[0].op else None
         if self.recom_pre is not None:
             self.recom_area.in_relations[self.recom_pre] = dom_area.in_relations[self.recom_pre]
             self.recom_pre.out_relations[self.recom_area] = dom_area.in_relations[self.recom_pre]
@@ -618,7 +620,7 @@ class GraphSplitByPattern:
         ids = self.index_op()
         dom_ops = list()
         dom_ops.extend(dom.ops)
-        dom_ops.sort(key=lambda op: ids[op])
+        dom_ops.sort(key=ids.get)
         subgraph = []
         subgraph = Graph('{}_area'.format(self.graph.name), dom_ops)
         return subgraph
@@ -631,15 +633,15 @@ class GraphSplitByPattern:
             # region successfully ends at inputs
             if not op.inputs:
                 region_ops.append(op)
-                return False, None, weight, True
+                return False, op, weight, True
             if op.inputs[0] in inputs and len(op.inputs) == 1 and \
                     PrimLib.iter_type(op) <= PrimLib.BROADCAST:
                 region_ops.append(op)
-                return False, None, weight, True
+                return False, op, weight, True
             # region fails to grow
             max_weight = 20
             if weight > max_weight or len(op.inputs) > 1 or PrimLib.iter_type(op) > PrimLib.BROADCAST:
-                return False, None, weight, False
+                return False, op, weight, False
             # region grows successfully
             weight = weight + 1
             region_ops.append(op)
@@ -681,16 +683,15 @@ class GraphSplitByPattern:
             _, dom_outputs = dom_graph.deduce_parameters()
             user_graph = self.to_subgraph(user_area)
             user_inputs, _ = user_graph.deduce_parameters()
-            edge = [t for t in dom_outputs if t in user_inputs]
-            return len(edge)
+            return len(list(t for t in dom_outputs if t in user_inputs))
 
         def _select_user_area(tail_tensor):
             user_areas = []
             for user_op in tail_tensor.to_ops:
-                user_area = self.area_map[user_op]
+                user_area = self.area_map.get(user_op)
                 if user_area.pattern == PrimLib.RESHAPE:
                     continue
-                edge_num = _get_edge_num(self.area_map[tail_tensor.op], user_area)
+                edge_num = _get_edge_num(self.area_map.get(tail_tensor.op), user_area)
                 if edge_num == 1 and not user_area in user_areas:
                     user_areas.append(user_area)
             return user_areas
@@ -757,21 +758,21 @@ class GraphSplitGpu(GraphSplitByPattern):
         def _broadcast_depth(dom):
             if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1 or \
                     dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH:
-                return None
+                return []
             a, r = list(dom.out_relations.items())[0]
             if _broadcast_pat_exclude(dom, a, r) or len(a.in_relations) != 1:
-                return None
+                return []
             return [a], False
 
         def _broadcast_width(dom):
             if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or \
                     dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH:
-                return None
+                return []
             fused = []
             for a, r in dom.out_relations.items():
                 if _broadcast_pat_exclude(dom, a, r) or not dom.check_acyclic(a) or \
                         (fused and fused[0].dom_op().output.shape != a.dom_op().output.shape):
-                    return None
+                    return []
                 fused.append(a)
             return fused, False
@@ -782,19 +783,19 @@ class GraphSplitGpu(GraphSplitByPattern):
 
         def _reduce_depth(dom):
             if dom.pattern != PrimLib.REDUCE or len(dom.in_relations) != 1:
-                return None
+                return []
             a, r = list(dom.in_relations.items())[0]
             if dom.ops[0].inputs[0].dtype == "float16" and a.is_output and len(a.ops) >= 10 and \
                     _is_atomic_add_available(dom):
                 # to evade the precision problem.
-                return None
+                return []
             if _reduce_pat_exclude(dom, a, r) or len(a.out_relations) != 1:
-                return None
+                return []
             return [a], True
 
         def _reduce_width(dom):
             if dom.pattern != PrimLib.REDUCE:
-                return None
+                return []
             fused = []
             for a, r in dom.in_relations.items():
                 if dom.ops[0].inputs[0].dtype == "float16" and a.is_output and len(a.ops) >= 10 and \
@@ -806,7 +807,7 @@ class GraphSplitGpu(GraphSplitByPattern):
             return fused, True
 
         def _is_atomic_add_available(dom):
-            if any(["Reduce" in x.prim for x in dom.ops[1:]]):
+            if any(("Reduce" in x.prim for x in dom.ops[1:])):
                 return False
             op = dom.ops[0]
             if "reduce_axis" in op.attrs:
@@ -816,21 +817,21 @@ class GraphSplitGpu(GraphSplitByPattern):
             else:
                 raise Exception("For '{}', can not find the attr 'reduce_axis' or 'axis'".format(op.prim))
             if len(op.inputs[0].shape) - 1 in reduce_axis:
-                reduce_size = prod_reduce(lambda x, y: x * y, [op.inputs[0].shape[i] for i in reduce_axis])
+                reduce_size = prod_reduce(lambda x, y: x * y, (op.inputs[0].shape[i] for i in reduce_axis))
                 return reduce_size >= 1024
             return True
 
         def _reduce_output(dom):
             if dom.pattern != PrimLib.REDUCE:
-                return None
+                return []
             if reduce_nums(dom.ops) > 1:
-                return None
+                return []
             if _is_atomic_add_available(dom):
-                return None
+                return []
             is_all_reduce = tensor_size(dom.ops[0].output) == 1
             # excluded large size all reduce
             if is_all_reduce and tensor_size(dom.ops[0].inputs[0]) > 1024 * 12:
-                return None
+                return []
 
             fused = []
             for a, r in dom.out_relations.items():
@@ -841,11 +842,11 @@ class GraphSplitGpu(GraphSplitByPattern):
 
         def _reduce_stitch(dom):
             if dom.pattern != PrimLib.REDUCE:
-                return None
+                return []
             if tensor_size(dom.ops[0].output) == 1:
-                return None
+                return []
             if tensor_size(dom.ops[0].inputs[0]) < 1024 * 12:
-                return None
+                return []
 
             fused = []
             for a, r in dom.out_relations.items():
@@ -862,7 +863,7 @@ class GraphSplitGpu(GraphSplitByPattern):
 
         def _transpose(dom):
             if len(dom.ops) != 1 or dom.ops[0].prim != "Transpose":
-                return None
+                return []
             fused = []
             for a, _ in dom.in_relations.items():
                 if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and len(a.ops) <= self.TRANSPOSE_FUSE_DEPTH:
@@ -871,7 +872,7 @@ class GraphSplitGpu(GraphSplitByPattern):
 
         def _strided_slice(dom):
             if dom.dom_op().prim != "StridedSlice":
-                return None
+                return []
             fused = []
             for a, _ in dom.in_relations.items():
                 if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \
@@ -882,7 +883,7 @@ class GraphSplitGpu(GraphSplitByPattern):
         def _gather_output(dom, reduce_fusion=False):
             gather_prims = ("Gather", "GatherNd")
             if not dom.dom_op().prim in gather_prims:
-                return None
+                return []
 
             def _reduce_exclude(op, axis_list):
                 """ Whether this operator should be excluded.
@@ -956,7 +957,7 @@ class GraphSplitGpu(GraphSplitByPattern):
                     start_ops.append(op)
             end_ops = []
             for op in total_ops:
-                if op.prim in end_prims and not any([to_op in total_ops for to_op in op.output.to_ops]):
+                if op.prim in end_prims and not any((to_op in total_ops for to_op in op.output.to_ops)):
                     end_ops.append(op)
 
             for start_op in start_ops:
@@ -977,21 +978,21 @@ class GraphSplitGpu(GraphSplitByPattern):
             for a, _ in dom.out_relations.items():
                 if _shape_consistent(gather_prims, appected_areas, dom, a) and dom.check_acyclic(a):
                     return [a], False
-            return None
+            return []
 
         def _broadcast_opaque(dom):
             """Fuse rule for TensorScatterAdd and UnsortedSegmentSum."""
             def _same_input(op1, op2):
-                return bool(set(op1.inputs.copy()) & set(op2.inputs.copy()))
+                return bool(set(op1.inputs) & set(op2.inputs))
 
             if len(dom.ops) != 1:
-                return None
+                return []
 
             # Only fuse the first input for `TensorScatterAdd`` and the first and second input for `UnsortedSegmentSum`.
             fuse_arg = {"TensorScatterAdd": slice(1, None), "UnsortedSegmentSum": slice(0, 2)}
             arg_idx = fuse_arg.get(dom.dom_op().prim, -1)
             if arg_idx == -1:
-                return None
+                return []
             fuse_tensor = dom.dom_op().inputs[arg_idx]
 
             for a, _ in dom.in_relations.items():
@@ -1000,21 +1001,21 @@ class GraphSplitGpu(GraphSplitByPattern):
                     return [a], True
                 # Rule 2: Fuse op(reshape/elementwise/broadcast) in specified position inputs.
                 if a.pattern <= PrimLib.BROADCAST and dom.check_acyclic(a) and \
                         any((op.output in fuse_tensor for op in a.ops)):
                     return [a], True
-            return None
+            return []
 
         def _h_broadcast(dom, a):
             if dom.pattern > PrimLib.BROADCAST:
-                return None
+                return []
             return a.pattern <= PrimLib.BROADCAST and dom.ops[0].output.shape == a.ops[0].output.shape
 
         def _h_reduce(dom, a):
             if dom.pattern != PrimLib.REDUCE or dom.stitch_info.stitch_ops:
-                return None
+                return []
             dom_op = dom.ops[0]
             if not PrimLib.is_reduce(dom_op) or _is_atomic_add_available(dom):
-                return None
+                return []
             op = a.ops[0]
             return a.pattern == PrimLib.REDUCE and not a.stitch_info.stitch_ops and \
                 PrimLib.is_reduce(op) and dom_op.inputs[0].shape == op.inputs[0].shape and \
@@ -1022,7 +1023,7 @@ class GraphSplitGpu(GraphSplitByPattern):
 
         def _h_opaque(dom, a):
             if dom.ops[0].prim not in {"StridedSlice"}:
-                return None
+                return []
             return a.ops[0].prim == dom.ops[0].prim and dom.ops[0].output.shape == a.ops[0].output.shape and \
                 dom.ops[0].inputs[0].shape == a.ops[0].inputs[0].shape
@@ -1103,20 +1104,20 @@ class GraphSplitAscend(GraphSplitByPattern):
 
         def _broadcast_depth(dom):
             if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1:
-                return None
+                return []
             a, r = list(dom.out_relations.items())[0]
             if _broadcast_pat_exclude(dom, a, r) or len(a.in_relations) != 1:
-                return None
+                return []
             return [a], False
 
         def _broadcast_width(dom):
             if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
-                return None
+                return []
             fused = []
             for a, r in dom.out_relations.items():
                 if _broadcast_pat_exclude(dom, a, r) or not dom.check_acyclic(a) or \
                         (fused and fused[0].dom_op().output.shape != a.dom_op().output.shape):
-                    return None
+                    return []
                 fused.append(a)
             return fused, False
@@ -1130,15 +1131,15 @@ class GraphSplitAscend(GraphSplitByPattern):
 
         def _reduce_depth(dom):
             if dom.pattern != PrimLib.REDUCE or len(dom.in_relations) != 1:
-                return None
+                return []
             a, r = list(dom.in_relations.items())[0]
             if _reduce_pat_exclude(dom, a, r) or len(a.out_relations) != 1:
-                return None
+                return []
             return [a], True
 
         def _reduce_width(dom):
             if dom.pattern != PrimLib.REDUCE:
-                return None
+                return []
             fused = []
             for a, r in dom.in_relations.items():
                 if not _reduce_pat_exclude(dom, a, r) and a.check_acyclic(dom):
@@ -1147,7 +1148,7 @@ class GraphSplitAscend(GraphSplitByPattern):
 
         def _matmul_depth(dom):
             if dom.dom_op().prim != "MatMul" and dom.dom_op().prim != "BatchMatMul":
-                return None
+                return []
             fused = []
             for a, _ in dom.out_relations.items():
                 if (((a.dom_op().prim == "AddN" or a.dom_op().prim == "Add" or a.dom_op().prim == "Cast")
@@ -1159,10 +1160,10 @@ class GraphSplitAscend(GraphSplitByPattern):
 
         def _reduce_output(dom):
             if dom.pattern != PrimLib.REDUCE:
-                return None
+                return []
             op_attrs = dom.dom_op().attrs
             if not op_attrs.get('reduce_output_fuse'):
-                return None
+                return []
             fused = []
             for a, r in dom.out_relations.items():
                 if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \
@@ -1172,11 +1173,11 @@ class GraphSplitAscend(GraphSplitByPattern):
 
         def _reduce_stitch(dom):
             if dom.pattern != PrimLib.REDUCE:
-                return None
+                return []
             if tensor_size(dom.ops[0].output) == 1:
-                return None
+                return []
             if tensor_size(dom.ops[0].inputs[0]) < 32 * 16 * 16:
-                return None
+                return []
 
             fused = []
             for a, r in dom.out_relations.items():
@@ -1233,7 +1234,7 @@ class GraphSplitAscend(GraphSplitByPattern):
 
         def _transdata(dom):
             if dom.dom_op().prim != "TransData":
-                return None
+                return []
             fused = []
             for a, _ in dom.in_relations.items():
                 if _transdata_pattern_support(dom, a) and a.check_acyclic(dom):
@@ -1270,99 +1271,11 @@ class GraphSplitAscend(GraphSplitByPattern):
             _fuse_once(fuse_func)
 
 
-class GraphSplitCpu(GraphSplitByPattern):
-    """Graph splitter"""
-    BROADCAST_FUSE_DEPTH = 20
-    REDUCE_FUSE_DEPTH = 20
-
-    def get_default_mode(self, op):
-        """Get default mode in CPU"""
-        del op
-        return self.Area.MODE_COMPOSITE
-
-    def pattern_fuse(self, fuse_func=None):
-        """fuse Areas by pattern"""
-
-        def _broadcast_pat_exclude(dom, a, r):
-            if a.pattern == PrimLib.REDUCE:
-                return dom.pattern > PrimLib.ELEMWISE or r > PrimLib.ELEMWISE
-            return a.pattern > PrimLib.REDUCE or r > PrimLib.BROADCAST
-
-        def _broadcast_depth(dom):
-            if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1 or \
-                    dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH:
-                return None
-            a, r = list(dom.out_relations.items())[0]
-            if _broadcast_pat_exclude(dom, a, r) or len(a.in_relations) != 1:
-                return None
-            return [a], False
-
-        def _broadcast_width(dom):
-            if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or \
-                    dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH:
-                return None
-            fused = []
-            for a, r in dom.out_relations.items():
-                if _broadcast_pat_exclude(dom, a, r) or not dom.check_acyclic(a) or \
-                        (fused and fused[0].dom_op().output.shape != a.dom_op().output.shape):
-                    return None
-                fused.append(a)
-            return fused, False
-
-        def _reduce_pat_exclude(_, a, r):
-            if len(a.ops) > self.REDUCE_FUSE_DEPTH:
-                return True
-            return a.pattern > PrimLib.ELEMWISE or r > PrimLib.REDUCE or r == PrimLib.BROADCAST
-
-        def _reduce_width(dom):
-            if dom.pattern != PrimLib.REDUCE:
-                return None
-            fused = []
-            for a, r in dom.in_relations.items():
-                if not _reduce_pat_exclude(dom, a, r) and a.check_acyclic(dom):
-                    fused.append(a)
-            return fused, True
-
-        def _reduce_depth(dom):
-            if dom.pattern != PrimLib.REDUCE or len(dom.in_relations) != 1:
-                return None
-            a, r = list(dom.in_relations.items())[0]
-            if _reduce_pat_exclude(dom, a, r) or len(a.out_relations) != 1:
-                return None
-            return [a], True
-
-        def _fuse_loop():
-            changed = True
-            while changed:
-                changed = False
-                changed = self.fuse(CommonPattern.reshape) or changed
-                changed = self.fuse(CommonPattern.assign) or changed
-                changed = self.fuse(CommonPattern.elemwise_depth) or changed
-                changed = self.fuse(CommonPattern.elemwise_width) or changed
-                changed = self.fuse(_reduce_depth) or changed
-                changed = self.fuse(_reduce_width) or changed
-                changed = self.fuse(_broadcast_depth) or changed
-                changed = self.fuse(_broadcast_width) or changed
-
-        def _fuse_once(fuse_func):
-            if fuse_func(CommonPattern.reshape) or fuse_func(CommonPattern.elemwise_depth) or \
-                    fuse_func(CommonPattern.elemwise_width) or fuse_func(_reduce_depth) or \
-                    fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or fuse_func(_broadcast_width):
-                return
-
-        if fuse_func is None:
-            _fuse_loop()
-        else:
-            _fuse_once(fuse_func)
-
-
 def split(graph, target, flags):
     """Split graph"""
     result = None
     if target == "cuda":
         result = GraphSplitGpu(graph, flags).split()
     elif target == "aicore":
         result = GraphSplitAscend(graph, flags).split()
     else:
-        result = GraphSplitCpu(graph, flags).split()
+        result = GraphSplitAscend(graph, flags).split()
     return result
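With GraphSplitCpu removed, split() no longer has a CPU-specific branch: every target other than "cuda" now ends up in GraphSplitAscend. A behaviorally equivalent, self-contained sketch of the resulting dispatch (the stub classes exist only to make the example runnable; they are not the real splitters):

# Sketch of the post-commit dispatch; stub classes stand in for the
# real GraphSplitGpu/GraphSplitAscend implementations.
class GraphSplitGpu:
    def __init__(self, graph, flags):
        self.graph = graph
    def split(self):
        return "gpu", self.graph

class GraphSplitAscend:
    def __init__(self, graph, flags):
        self.graph = graph
    def split(self):
        return "ascend", self.graph

def split(graph, target, flags):
    """Split graph"""
    if target == "cuda":
        return GraphSplitGpu(graph, flags).split()
    # "aicore" and any remaining target (including CPU) share this path now
    return GraphSplitAscend(graph, flags).split()

assert split("g", "cpu", None)[0] == "ascend"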