consider no tensor input in graph

This commit is contained in:
lingyunli63 2021-07-19 19:57:35 +08:00
parent dda605cd8f
commit 32571f5a0c
1 changed files with 16 additions and 11 deletions

View File

@ -192,7 +192,8 @@ class GraphSplitByPattern:
tail_tensor = area.recompute_ops[-1].output
#copy tensors, all copied are Tensor.PARA_NONE
tensor_map = {}
tensor_map[area.recompute_ops[0].inputs[0]] = area.recompute_ops[0].inputs[0]
if area.recompute_ops[0].inputs:
tensor_map[area.recompute_ops[0].inputs[0]] = area.recompute_ops[0].inputs[0]
for op in area.recompute_ops:
orig_tensor = op.output
cp_tensor = Tensor(orig_tensor.name, orig_tensor.shape, orig_tensor.dtype, orig_tensor.data_format)
@ -200,7 +201,8 @@ class GraphSplitByPattern:
#copy ops
cp_ops = []
for op in area.recompute_ops:
cp_op = Operator(op.prim, [tensor_map[op.inputs[0]]], tensor_map[op.output], op.attrs)
inputs = [tensor_map[op.inputs[0]]] if op.inputs else []
cp_op = Operator(op.prim, inputs, tensor_map[op.output], op.attrs)
cp_op.all_inputs = cp_op.inputs
cp_ops.append(cp_op)
area.ori_op_map[cp_op] = op
@ -415,7 +417,7 @@ class GraphSplitByPattern:
user_area.in_relations[self.recom_area] = self.dom_user_r
self.recom_area.out_relations[user_area] = self.dom_user_r
#connect recom_pre and recom_area
self.recom_pre = self.area_map[ops[0].inputs[0].op] if ops[0].inputs[0].op else None
self.recom_pre = self.area_map[ops[0].inputs[0].op] if ops[0].inputs and ops[0].inputs[0].op else None
if self.recom_pre is not None:
self.recom_area.in_relations[self.recom_pre] = dom_area.in_relations[self.recom_pre]
self.recom_pre.out_relations[self.recom_area] = dom_area.in_relations[self.recom_pre]
@ -454,19 +456,22 @@ class GraphSplitByPattern:
"""extract all the cheap regions in dom area, toposort each region before return"""
def _grow_region(region_ops, op, weight, inputs):
"""include op to region_ops if region grow"""
# region successfully ends at input
# region successfully ends at inputs
if not op.inputs:
region_ops.append(op)
return False, None, weight, True
if op.inputs[0] in inputs and len(op.inputs) == 1 and \
PrimLib.iter_type(op) <= PrimLib.BROADCAST:
region_ops.append(op)
return False, None, weight
return False, None, weight, True
#region fails to grow
MAX_WEIGHT = 20
if weight > MAX_WEIGHT or len(op.inputs) > 1 or PrimLib.iter_type(op) > PrimLib.BROADCAST:
return False, None, weight
return False, None, weight, False
#region grows successfully
weight = weight + 1
region_ops.append(op)
return True, op.inputs[0].op, weight
return True, op.inputs[0].op, weight, False
def _find_cheap_regions(dom):
sub = self.to_subgraph(dom)
@ -480,13 +485,13 @@ class GraphSplitByPattern:
grow = True
candidate_op = output.op
weight = 1
result = False
while grow:
grow, candidate_op, weight = _grow_region(region_ops, candidate_op, weight, inputs)
# region ends at input and not empty
if region_ops and region_ops[-1].inputs[0] in inputs:
grow, candidate_op, weight, result = _grow_region(region_ops, candidate_op, weight, inputs)
if result:
region_ops.reverse()
# tensor size should equal or becomes larger(cast up, broadcast)
if region_ops[0].inputs[0].get_size() > region_ops[-1].output.get_size():
if region_ops[0].inputs and region_ops[0].inputs[0].get_size() > region_ops[-1].output.get_size():
continue
cheap_regions.append(region_ops)
return cheap_regions