Consider the case of no tensor input in the graph
This commit is contained in:
parent
dda605cd8f
commit
32571f5a0c
|
@ -192,7 +192,8 @@ class GraphSplitByPattern:
|
|||
tail_tensor = area.recompute_ops[-1].output
|
||||
#copy tensors, all copied are Tensor.PARA_NONE
|
||||
tensor_map = {}
|
||||
tensor_map[area.recompute_ops[0].inputs[0]] = area.recompute_ops[0].inputs[0]
|
||||
if area.recompute_ops[0].inputs:
|
||||
tensor_map[area.recompute_ops[0].inputs[0]] = area.recompute_ops[0].inputs[0]
|
||||
for op in area.recompute_ops:
|
||||
orig_tensor = op.output
|
||||
cp_tensor = Tensor(orig_tensor.name, orig_tensor.shape, orig_tensor.dtype, orig_tensor.data_format)
|
||||
|
@ -200,7 +201,8 @@ class GraphSplitByPattern:
|
|||
#copy ops
|
||||
cp_ops = []
|
||||
for op in area.recompute_ops:
|
||||
cp_op = Operator(op.prim, [tensor_map[op.inputs[0]]], tensor_map[op.output], op.attrs)
|
||||
inputs = [tensor_map[op.inputs[0]]] if op.inputs else []
|
||||
cp_op = Operator(op.prim, inputs, tensor_map[op.output], op.attrs)
|
||||
cp_op.all_inputs = cp_op.inputs
|
||||
cp_ops.append(cp_op)
|
||||
area.ori_op_map[cp_op] = op
|
||||
|
@ -415,7 +417,7 @@ class GraphSplitByPattern:
|
|||
user_area.in_relations[self.recom_area] = self.dom_user_r
|
||||
self.recom_area.out_relations[user_area] = self.dom_user_r
|
||||
#connect recom_pre and recom_area
|
||||
self.recom_pre = self.area_map[ops[0].inputs[0].op] if ops[0].inputs[0].op else None
|
||||
self.recom_pre = self.area_map[ops[0].inputs[0].op] if ops[0].inputs and ops[0].inputs[0].op else None
|
||||
if self.recom_pre is not None:
|
||||
self.recom_area.in_relations[self.recom_pre] = dom_area.in_relations[self.recom_pre]
|
||||
self.recom_pre.out_relations[self.recom_area] = dom_area.in_relations[self.recom_pre]
|
||||
|
@ -454,19 +456,22 @@ class GraphSplitByPattern:
|
|||
"""extract all the cheap regions in dom area, toposort each region before return"""
|
||||
def _grow_region(region_ops, op, weight, inputs):
|
||||
"""include op to region_ops if region grow"""
|
||||
# region successfully ends at input
|
||||
# region successfully ends at inputs
|
||||
if not op.inputs:
|
||||
region_ops.append(op)
|
||||
return False, None, weight, True
|
||||
if op.inputs[0] in inputs and len(op.inputs) == 1 and \
|
||||
PrimLib.iter_type(op) <= PrimLib.BROADCAST:
|
||||
region_ops.append(op)
|
||||
return False, None, weight
|
||||
return False, None, weight, True
|
||||
#region fails to grow
|
||||
MAX_WEIGHT = 20
|
||||
if weight > MAX_WEIGHT or len(op.inputs) > 1 or PrimLib.iter_type(op) > PrimLib.BROADCAST:
|
||||
return False, None, weight
|
||||
return False, None, weight, False
|
||||
#region grows successfully
|
||||
weight = weight + 1
|
||||
region_ops.append(op)
|
||||
return True, op.inputs[0].op, weight
|
||||
return True, op.inputs[0].op, weight, False
|
||||
|
||||
def _find_cheap_regions(dom):
|
||||
sub = self.to_subgraph(dom)
|
||||
|
@ -480,13 +485,13 @@ class GraphSplitByPattern:
|
|||
grow = True
|
||||
candidate_op = output.op
|
||||
weight = 1
|
||||
result = False
|
||||
while grow:
|
||||
grow, candidate_op, weight = _grow_region(region_ops, candidate_op, weight, inputs)
|
||||
# region ends at input and not empty
|
||||
if region_ops and region_ops[-1].inputs[0] in inputs:
|
||||
grow, candidate_op, weight, result = _grow_region(region_ops, candidate_op, weight, inputs)
|
||||
if result:
|
||||
region_ops.reverse()
|
||||
# tensor size should equal or becomes larger(cast up, broadcast)
|
||||
if region_ops[0].inputs[0].get_size() > region_ops[-1].output.get_size():
|
||||
if region_ops[0].inputs and region_ops[0].inputs[0].get_size() > region_ops[-1].output.get_size():
|
||||
continue
|
||||
cheap_regions.append(region_ops)
|
||||
return cheap_regions
|
||||
|
|
Loading…
Reference in New Issue