!34057 modify for rewrite pattern

Merge pull request !34057 from 于振华/rewrite_pattern
This commit is contained in:
i-robot 2022-05-11 07:50:52 +00:00 committed by Gitee
commit c3647bc54b
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 367 additions and 393 deletions

View File

@ -279,12 +279,9 @@ class PatternEngine:
RuntimeError: If `SymbolTree` has no return node.
"""
root: Node = stree.get_return_node()
if root is None:
raise RuntimeError("SymbolTree should be inited and has return node")
changed = False
# IR match
queue: [Node] = [root]
queue: [Node] = stree.get_inputs()
# Why need visited: we don't need or should not to visit same node multi-times because pre-visited node may
# already been erased from SymbolTree.
# When will we visit same node multi-times:
@ -314,20 +311,20 @@ class PatternEngine:
subtree = TreeNodeHelper.get_sub_tree(cur_node)
self.apply(subtree)
visited.append(cur_node)
queue.extend(cur_node.get_inputs())
queue.extend(cur_node.get_users())
continue
visited.append(cur_node)
matched, matched_dict = self._match(self._pattern, cur_node)
# not matched
if not matched or not PatternEngine._check_match(self._pattern, matched_dict):
queue.extend(cur_node.get_inputs())
queue.extend(cur_node.get_users())
continue
# matched
new_nodes: [Node] = []
if self._replacement is not None:
new_nodes: [Node] = self._replacement(self._pattern, self._is_chain, matched_dict)
if not new_nodes: # if replacement is empty, do nothing
queue.extend(cur_node.get_inputs())
queue.extend(cur_node.get_users())
else: # replace cur_node with new_nodes
changed = True
root = PatternEngine._multi_to_multi_replace(stree, cur_node, matched_dict, new_nodes)

View File

@ -56,7 +56,7 @@ class SymbolTree:
Get all nodes of corresponding network.
Returns:
A dict mapping from name of node to node.
A generator for node of current `SymbolTree`.
"""
for node in self._symbol_tree.nodes():
yield Node(node)
@ -76,20 +76,14 @@ class SymbolTree:
return None
return Node(node_impl)
def get_return_node(self) -> Node:
def get_inputs(self) -> [Node]:
"""
Get return node of current `SymbolTree`.
A `SymbolTree` can and should have one return node corresponding to return statement of forward method of
network.
Get 'input' nodes of current `SymbolTree`.
Returns:
An instance of node represents return node.
[Node]: The node list of the current 'Symboltree'.
"""
ret = self._symbol_tree.get_return_node()
if ret is None:
raise RuntimeError("SymbolTree is not well inited, can not find return node.")
return Node(ret)
return [Node(node_impl) for node_impl in self._symbol_tree.get_inputs()]
def before(self, node: Node):
"""

View File

@ -361,46 +361,6 @@ class Node:
normalized_args[sorted_kwarg[0]] = sorted_kwarg[1]
normalized_args_keys.append(sorted_kwarg[0])
def _get_normalized_args(self, args: [ScopedValue], kwargs: {str: ScopedValue}) -> dict:
"""
Merge args and kwargs to normalized args.
The keys of args are obtained from the construct function of type(self._instance).
Args:
args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class.
Raises:
RuntimeError: Input args are invalid.
RuntimeError: Arg name already exist in kwargs.
Returns:
The normalized args.
"""
if not args:
args = []
if not kwargs:
kwargs = {}
normalized_args: dict = dict()
if self._instance and hasattr(type(self._instance), "construct"):
parameters = inspect.signature(type(self._instance).construct).parameters
names = Node._get_construct_arg_names(parameters)
Node._map_args_names(names, args, kwargs, self._normalized_args_keys, normalized_args)
else:
logger.debug("fail to get arg name from op, using arg_xx for args' name")
arg_temp_name, suffix = "arg", 0
for arg in args:
arg_key = "{}_{}".format(arg_temp_name, suffix)
while arg_key in kwargs.keys() or arg_key in normalized_args.keys():
suffix += 1
arg_key = "{}_{}".format(arg_temp_name, suffix)
normalized_args[arg_key] = arg
self._normalized_args_keys.append(arg_key)
for arg_key, value in kwargs.items():
normalized_args[arg_key] = value
self._normalized_args_keys.append(arg_key)
return normalized_args
@staticmethod
def _handle_custom_obj_in_args(args: [ScopedValue]) -> [ScopedValue]:
"""
@ -468,6 +428,27 @@ class Node:
raise RuntimeError("Invalid symbol type: ", target)
return results
@staticmethod
def _get_cell_or_prim_op_attribute(obj) -> dict:
"""
Find attributes of cell-op or primitive-op.
Args:
obj: A cell-op or a primitive-op.
Returns:
A dict represents attributes of input 'obj'.
"""
attributes = {}
if obj is None:
return attributes
for k, v in obj.__dict__.items():
if k.startswith("_"):
continue
attributes[k] = v
attributes["cls"] = obj.__class__
return attributes
def get_prev(self) -> 'Node':
"""
Get previous node of current node in source code order.
@ -477,10 +458,6 @@ class Node:
"""
return self._prev
def set_prev(self, prev):
"""Set previous node of current node in source code order. """
self._prev = prev
def get_next(self) -> 'Node':
"""
Get next node of current node in source code order.
@ -490,10 +467,6 @@ class Node:
"""
return self._next
def set_next(self, _next):
"""Set next node of current node in source code order."""
self._next = _next
def has_same_ast(self, node: Union['Node', ast.AST]) -> bool:
"""
Check if other node holds same ast node with self.
@ -505,7 +478,7 @@ class Node:
A bool.
"""
if isinstance(node, Node):
return self.has_same_ast(node.get_ast())
return self.has_same_ast(node._ast_node)
if isinstance(node, ast.AST):
return id(self._ast_node) == id(node)
return False
@ -536,165 +509,14 @@ class Node:
def set_belong_symbol_tree(self, symbol_tree):
self._belong_tree = symbol_tree
def _sync_assign_func_to_ast(self):
"""Sync func of ast.Call of ast.Assign from self._name when NodeType is CallCell or CallPrimitive."""
if self._ast_node is None:
return
assign_ast = self._ast_node
if not isinstance(assign_ast, ast.Assign):
raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast))
call_ast = assign_ast.value
if not isinstance(call_ast, ast.Call):
raise TypeError("call_ast should be ast.Call, got: ", type(call_ast))
func_ast = call_ast.func
if not self._func.value:
if isinstance(func_ast, ast.Name):
func_ast.id = self._func.value
else:
call_ast.func = ast.Name(self._func.value, ast.Store())
else:
if isinstance(func_ast, ast.Attribute):
func_value = func_ast.value
if not isinstance(func_value, ast.Name):
raise RuntimeError("Only support ast.Name as value of attribute ", type(func_ast.value))
func_value.id = self._func.scope
func_ast.attr = self._func.value
else:
call_ast.func = ast.Attribute(ast.Name(self._func.scope, ast.Load()), self._func.value, ast.Store())
ast.fix_missing_locations(assign_ast)
def _sync_assign_targets_to_ast(self):
"""Sync targets of ast.Assign from self._targets when NodeType is CallCell, CallPrimitive or CallMethod."""
if self._ast_node is None:
return
assign_ast = self._ast_node
if not isinstance(assign_ast, ast.Assign):
raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast))
# update targets
targets_ast = assign_ast.targets
if isinstance(targets_ast[0], ast.Tuple) and len(self._targets) != len(targets_ast[0].elts):
raise RuntimeError("self._targets should have the same length as targets_ast's elts")
if not isinstance(targets_ast[0], ast.Tuple) and len(self._targets) != len(targets_ast):
raise RuntimeError("self._targets should have targets_ast same length")
for i in range(0, len(self._targets)):
target = self._targets[i]
target_ast = targets_ast[0]
if isinstance(target_ast, ast.Name):
target_ast.id = target.value
elif isinstance(target_ast, ast.Tuple):
if not isinstance(target_ast.elts[i], ast.Name):
raise TypeError("target should be ast.Name, got:", type(target_ast.elts[i]))
target_ast.elts[i].id = target.value
else:
raise TypeError("target_ast should be ast.Name or ast.Tuple, got: ", type(target_ast))
target_ast.id = target.value
ast.fix_missing_locations(assign_ast)
def _sync_call_cell_args_to_ast(self):
"""Sync args of ast.Cell of ast.Assign from self._normalized_args when NodeType is CallCell or CallPrimitive."""
if self._ast_node is None:
return
assign_ast = self._ast_node
if not isinstance(assign_ast, ast.Assign):
raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast))
assign_value = assign_ast.value
if not isinstance(assign_value, ast.Call):
return
keywords_ast = assign_value.keywords
args_ast = assign_value.args
if len(self._normalized_args_keys) != (len(keywords_ast) + len(args_ast)):
raise RuntimeError("ast keywords plus args len is not equal to self._normalized_args value")
for arg_index in range(self._args_num):
arg_ast = args_ast[arg_index]
AstModifier.update_arg_value(self._normalized_args.get(self._normalized_args_keys[arg_index]), arg_ast)
# the order of kwargs may not the same as that in keywords_ast
keyword_map_index = {}
for index, keyword_ast in enumerate(keywords_ast):
keyword_map_index[keyword_ast.arg] = index
for keyword_index in range(self._kwargs_num):
key = self._normalized_args_keys[keyword_index + self._args_num]
AstModifier.update_arg_value(self._normalized_args.get(key),
keywords_ast[keyword_map_index.get(key)].value)
def _sync_call_method_args_to_ast(self):
"""Sync args of ast.Cell of ast.Assign from self._normalized_args when NodeType is CallMethod."""
if self._ast_node is None:
return
assign_ast = self._ast_node
if not isinstance(assign_ast, ast.Assign):
raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast))
assign_value = assign_ast.value
if self._func == PASS_THROUGH_METHOD:
if isinstance(assign_value, ast.Name):
if len(self._normalized_args_keys) != 1:
raise RuntimeError("self._normalized_args_keys should have 1 elements")
arg = self._normalized_args.get(self._normalized_args_keys[0])
if arg.type != ValueType.NamingValue:
raise RuntimeError("arg.type should equal to ValueType.NamingValue")
if arg.scope != "":
raise RuntimeError("arg.scope should be empty")
assign_value.id = arg.value
elif isinstance(assign_value, ast.Attribute):
if len(self._normalized_args_keys) != 1:
raise RuntimeError("self._normalized_args_keys should have 1 elements")
arg = self._normalized_args.get(self._normalized_args_keys[0])
if arg.type != ValueType.NamingValue:
raise RuntimeError("arg.type should equal to ValueType.NamingValue")
assign_value.attr = arg.value
assign_value_value = assign_value.value
if not isinstance(assign_value_value, ast.Name):
raise RuntimeError("Only support ast.Name as value of attribute ", type(assign_value_value))
assign_value_value.id = arg.scope
else:
if len(self._normalized_args_keys) != 1:
raise RuntimeError("self._normalized_args_keys should have 1 elements")
arg = self._normalized_args.get(self._normalized_args_keys[0])
if arg.type not in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
raise RuntimeError("arg should be an IntValue, FloatValue or StringValue")
if arg.scope != "":
raise RuntimeError("arg.scope should be empty")
assign_value.value = arg.value
else:
raise RuntimeError("Only support pass_through method as call_method now, ", self._func.value)
def _sync_return_node_to_ast(self):
"""Sync return value of ast.Return from self._normalized_args when NodeType is Output."""
if self._ast_node is None:
return
return_ast = self._ast_node
if not isinstance(return_ast, ast.Return):
raise TypeError("return_ast should be ast.Return, got: ", type(return_ast))
# update args
return_value_ast = return_ast.value
if isinstance(return_value_ast, ast.Name):
if len(self._normalized_args_keys) != 1:
raise RuntimeError("self._normalized_args_keys should have 1 elements")
return_value_ast.id = self._normalized_args.get(self._normalized_args_keys[0]).value
elif isinstance(return_value_ast, ast.Tuple):
elements = return_value_ast.elts
if len(self._normalized_args.values()) != len(elements):
raise RuntimeError("self._normalized_args.values() should have elements same length")
for elt_index, elt in enumerate(elements):
if not isinstance(elt, ast.Name):
raise RuntimeError("Only support ast.Name as return value: ", elt)
arg = self._normalized_args.get(self._normalized_args_keys[elt_index])
if not isinstance(arg, ScopedValue):
raise TypeError("arg should be ScopedValue, got: ", type(arg))
elt.id = arg.value
else:
raise RuntimeError("Unsupported return value type: ", return_value_ast)
ast.fix_missing_locations(return_ast)
def isolate(self):
"""Link prev node to next node and isolate node from source code order list."""
origin_prev: Optional[Node] = self._prev
origin_next: Optional[Node] = self._next
if origin_prev is not None:
origin_prev.set_next(origin_next)
origin_prev._next = origin_next
if origin_next is not None:
origin_next.set_prev(origin_prev)
origin_next._prev = origin_prev
self._prev = None
self._next = None
@ -708,9 +530,9 @@ class Node:
node.isolate()
origin_prev: Optional[Node] = self._prev
if origin_prev is not None:
origin_prev.set_next(node)
node.set_prev(origin_prev)
node.set_next(self)
origin_prev._next = node
node._prev = origin_prev
node._next = self
self._prev = node
def insert_after(self, node: 'Node'):
@ -723,10 +545,10 @@ class Node:
node.isolate()
origin_next: Optional[Node] = self._next
self._next = node
node.set_prev(self)
node.set_next(origin_next)
node._prev = self
node._next = origin_next
if origin_next is not None:
origin_next.set_prev(node)
origin_next._prev = node
def get_inputs(self) -> ['Node']:
"""
@ -867,15 +689,6 @@ class Node:
"""
return self._instance
def _sync_arg(self):
"""Sync _normalized_args to corresponding ast node when updated."""
if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree):
self._sync_call_cell_args_to_ast()
elif self._node_type == NodeType.Output:
self._sync_return_node_to_ast()
elif self._node_type == NodeType.CallMethod:
self._sync_call_method_args_to_ast()
def set_arg_by_node(self, arg_idx: int, node: 'Node', out_idx: Optional[int] = None):
"""
Set argument by another Node.
@ -896,12 +709,12 @@ class Node:
if arg_idx >= self._args_num or arg_idx < 0:
raise RuntimeError("arg_idx out of range: ", arg_idx)
if out_idx is None:
if len(node.get_targets()) != 1:
if len(node._targets) != 1:
raise RuntimeError("node should has one output when out_idx is not provided")
out_idx = 0
if out_idx >= len(node.get_targets()):
if out_idx >= len(node._targets):
raise RuntimeError("out_idx out of range: ", out_idx)
new_arg = node.get_targets()[out_idx]
new_arg = node._targets[out_idx]
self._normalized_args[self._normalized_args_keys[arg_idx]] = new_arg
self._sync_arg()
@ -1091,26 +904,205 @@ class Node:
"""
return self._attribute.get(key)
@staticmethod
def _get_cell_or_prim_op_attribute(obj) -> dict:
def _get_normalized_args(self, args: [ScopedValue], kwargs: {str: ScopedValue}) -> dict:
"""
Find attributes of cell-op or primitive-op.
Merge args and kwargs to normalized args.
The keys of args are obtained from the construct function of type(self._instance).
Args:
obj: A cell-op or a primitive-op.
args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class.
Raises:
RuntimeError: Input args are invalid.
RuntimeError: Arg name already exist in kwargs.
Returns:
A dict represents attributes of input 'obj'.
The normalized args.
"""
attributes = {}
if obj is None:
return attributes
for k, v in obj.__dict__.items():
if k.startswith("_"):
continue
attributes[k] = v
attributes["cls"] = obj.__class__
return attributes
if not args:
args = []
if not kwargs:
kwargs = {}
normalized_args: dict = dict()
if self._instance and hasattr(type(self._instance), "construct"):
parameters = inspect.signature(type(self._instance).construct).parameters
names = Node._get_construct_arg_names(parameters)
Node._map_args_names(names, args, kwargs, self._normalized_args_keys, normalized_args)
else:
logger.debug("fail to get arg name from op, using arg_xx for args' name")
arg_temp_name, suffix = "arg", 0
for arg in args:
arg_key = "{}_{}".format(arg_temp_name, suffix)
while arg_key in kwargs.keys() or arg_key in normalized_args.keys():
suffix += 1
arg_key = "{}_{}".format(arg_temp_name, suffix)
normalized_args[arg_key] = arg
self._normalized_args_keys.append(arg_key)
for arg_key, value in kwargs.items():
normalized_args[arg_key] = value
self._normalized_args_keys.append(arg_key)
return normalized_args
def _sync_assign_func_to_ast(self):
"""Sync func of ast.Call of ast.Assign from self._name when NodeType is CallCell or CallPrimitive."""
if self._ast_node is None:
return
assign_ast = self._ast_node
if not isinstance(assign_ast, ast.Assign):
raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast))
call_ast = assign_ast.value
if not isinstance(call_ast, ast.Call):
raise TypeError("call_ast should be ast.Call, got: ", type(call_ast))
func_ast = call_ast.func
if not self._func.value:
if isinstance(func_ast, ast.Name):
func_ast.id = self._func.value
else:
call_ast.func = ast.Name(self._func.value, ast.Store())
else:
if isinstance(func_ast, ast.Attribute):
func_value = func_ast.value
if not isinstance(func_value, ast.Name):
raise RuntimeError("Only support ast.Name as value of attribute ", type(func_ast.value))
func_value.id = self._func.scope
func_ast.attr = self._func.value
else:
call_ast.func = ast.Attribute(ast.Name(self._func.scope, ast.Load()), self._func.value, ast.Store())
ast.fix_missing_locations(assign_ast)
def _sync_assign_targets_to_ast(self):
"""Sync targets of ast.Assign from self._targets when NodeType is CallCell, CallPrimitive or CallMethod."""
if self._ast_node is None:
return
assign_ast = self._ast_node
if not isinstance(assign_ast, ast.Assign):
raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast))
# update targets
targets_ast = assign_ast.targets
if isinstance(targets_ast[0], ast.Tuple) and len(self._targets) != len(targets_ast[0].elts):
raise RuntimeError("self._targets should have the same length as targets_ast's elts")
if not isinstance(targets_ast[0], ast.Tuple) and len(self._targets) != len(targets_ast):
raise RuntimeError("self._targets should have targets_ast same length")
for i in range(0, len(self._targets)):
target = self._targets[i]
target_ast = targets_ast[0]
if isinstance(target_ast, ast.Name):
target_ast.id = target.value
elif isinstance(target_ast, ast.Tuple):
if not isinstance(target_ast.elts[i], ast.Name):
raise TypeError("target should be ast.Name, got:", type(target_ast.elts[i]))
target_ast.elts[i].id = target.value
else:
raise TypeError("target_ast should be ast.Name or ast.Tuple, got: ", type(target_ast))
target_ast.id = target.value
ast.fix_missing_locations(assign_ast)
def _sync_call_cell_args_to_ast(self):
"""Sync args of ast.Cell of ast.Assign from self._normalized_args when NodeType is CallCell or CallPrimitive."""
if self._ast_node is None:
return
assign_ast = self._ast_node
if not isinstance(assign_ast, ast.Assign):
raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast))
assign_value = assign_ast.value
if not isinstance(assign_value, ast.Call):
return
keywords_ast = assign_value.keywords
args_ast = assign_value.args
if len(self._normalized_args_keys) != (len(keywords_ast) + len(args_ast)):
raise RuntimeError("ast keywords plus args len is not equal to self._normalized_args value")
for arg_index in range(self._args_num):
arg_ast = args_ast[arg_index]
AstModifier.update_arg_value(self._normalized_args.get(self._normalized_args_keys[arg_index]), arg_ast)
# the order of kwargs may not the same as that in keywords_ast
keyword_map_index = {}
for index, keyword_ast in enumerate(keywords_ast):
keyword_map_index[keyword_ast.arg] = index
for keyword_index in range(self._kwargs_num):
key = self._normalized_args_keys[keyword_index + self._args_num]
AstModifier.update_arg_value(self._normalized_args.get(key),
keywords_ast[keyword_map_index.get(key)].value)
def _sync_call_method_args_to_ast(self):
"""Sync args of ast.Cell of ast.Assign from self._normalized_args when NodeType is CallMethod."""
if self._ast_node is None:
return
assign_ast = self._ast_node
if not isinstance(assign_ast, ast.Assign):
raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast))
assign_value = assign_ast.value
if self._func == PASS_THROUGH_METHOD:
if isinstance(assign_value, ast.Name):
if len(self._normalized_args_keys) != 1:
raise RuntimeError("self._normalized_args_keys should have 1 elements")
arg = self._normalized_args.get(self._normalized_args_keys[0])
if arg.type != ValueType.NamingValue:
raise RuntimeError("arg.type should equal to ValueType.NamingValue")
if arg.scope != "":
raise RuntimeError("arg.scope should be empty")
assign_value.id = arg.value
elif isinstance(assign_value, ast.Attribute):
if len(self._normalized_args_keys) != 1:
raise RuntimeError("self._normalized_args_keys should have 1 elements")
arg = self._normalized_args.get(self._normalized_args_keys[0])
if arg.type != ValueType.NamingValue:
raise RuntimeError("arg.type should equal to ValueType.NamingValue")
assign_value.attr = arg.value
assign_value_value = assign_value.value
if not isinstance(assign_value_value, ast.Name):
raise RuntimeError("Only support ast.Name as value of attribute ", type(assign_value_value))
assign_value_value.id = arg.scope
else:
if len(self._normalized_args_keys) != 1:
raise RuntimeError("self._normalized_args_keys should have 1 elements")
arg = self._normalized_args.get(self._normalized_args_keys[0])
if arg.type not in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
raise RuntimeError("arg should be an IntValue, FloatValue or StringValue")
if arg.scope != "":
raise RuntimeError("arg.scope should be empty")
assign_value.value = arg.value
else:
raise RuntimeError("Only support pass_through method as call_method now, ", self._func.value)
def _sync_return_node_to_ast(self):
"""Sync return value of ast.Return from self._normalized_args when NodeType is Output."""
if self._ast_node is None:
return
return_ast = self._ast_node
if not isinstance(return_ast, ast.Return):
raise TypeError("return_ast should be ast.Return, got: ", type(return_ast))
# update args
return_value_ast = return_ast.value
if isinstance(return_value_ast, ast.Name):
if len(self._normalized_args_keys) != 1:
raise RuntimeError("self._normalized_args_keys should have 1 elements")
return_value_ast.id = self._normalized_args.get(self._normalized_args_keys[0]).value
elif isinstance(return_value_ast, ast.Tuple):
elements = return_value_ast.elts
if len(self._normalized_args.values()) != len(elements):
raise RuntimeError("self._normalized_args.values() should have elements same length")
for elt_index, elt in enumerate(elements):
if not isinstance(elt, ast.Name):
raise RuntimeError("Only support ast.Name as return value: ", elt)
arg = self._normalized_args.get(self._normalized_args_keys[elt_index])
if not isinstance(arg, ScopedValue):
raise TypeError("arg should be ScopedValue, got: ", type(arg))
elt.id = arg.value
else:
raise RuntimeError("Unsupported return value type: ", return_value_ast)
ast.fix_missing_locations(return_ast)
def _sync_arg(self):
"""Sync _normalized_args to corresponding ast node when updated."""
if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree):
self._sync_call_cell_args_to_ast()
elif self._node_type == NodeType.Output:
self._sync_return_node_to_ast()
elif self._node_type == NodeType.CallMethod:
self._sync_call_method_args_to_ast()
class TreeNode(Node):

View File

@ -116,13 +116,77 @@ class SymbolTree(Observer, Observable):
self._modified = False
self._node_visitor = None
@staticmethod
def _link_nodes_and_find_root(nodes: [Node]) -> Node:
"""
Find inputs for all nodes created by Replacement according to their targets and arguments.
Find root node of all nodes created by Replacement. One and Only one root should be found.
Args:
nodes (list[Node]): A list of instance of Node created by Replacement.
Returns:
An instance of Node represents root of input nodes.
"""
consumers: [ScopedValue] = []
target_dict: {ScopedValue: Node} = {}
for node in nodes:
consumers.extend(node.get_args())
for _, arg in node.get_kwargs():
consumers.append(arg)
for target in node.get_targets():
if target_dict.get(target) is not None:
raise RuntimeError("Target of node duplicated")
target_dict[target] = node
# find root node
root = None
for node in nodes:
used = 0
for target in node.get_targets():
if target in consumers:
used += 1
break
if used == 0:
if root is not None:
raise RuntimeError("Replacement should only has one root")
root = node
if root is None:
raise RuntimeError("No root node found in replacement nodes")
# link node's input
for node in nodes:
inputs = []
for _, arg in node.get_normalized_args().items():
node_input: Node = target_dict.get(arg)
inputs.append(node_input)
node.set_inputs(inputs)
return root
@staticmethod
def _find_all_class_in_symboltree(stree: 'SymbolTree', seen_class: {type, str}, allow_class_name: [], replacers):
"""Find all non-duplicated class name of SymbolTree recursively."""
replacer = AstReplacer(stree._class_ast)
replacers.append(replacer)
for node in stree.nodes():
if not isinstance(node, TreeNode):
continue
sub_stree: SymbolTree = node.symbol_tree
SymbolTree._find_all_class_in_symboltree(sub_stree, seen_class, allow_class_name, replacers)
# all modified ast.ClassDef should export to code
if sub_stree._modified:
allow_class_name.append(sub_stree._class_ast.name)
continue
# all un-modified ast.ClassDef only keep one instance
seen_cls_name = seen_class.get(type(sub_stree.get_origin_network()))
if seen_cls_name is not None:
replacer.replace_all(sub_stree._class_ast.name, seen_cls_name)
else:
seen_class[type(sub_stree.get_origin_network())] = sub_stree._class_ast.name
allow_class_name.append(sub_stree._class_ast.name)
def finish_build(self):
self.add_event(Event.TopologicalChangeEvent)
def _on_change(self, event: Event):
self._modified = True
self.changed(event)
def get_ori_cls_name(self) -> str:
"""
Get class name of original network.
@ -234,15 +298,6 @@ class SymbolTree(Observer, Observable):
"""
return self._head
def get_return_node(self):
"""
Getter of `_return` which represents return statement of forward method of network.
Returns:
An instance of node.
"""
return self._return
def get_origin_network(self):
"""
Getter of `_origin_network`.
@ -295,14 +350,6 @@ class SymbolTree(Observer, Observable):
return self._nodes.get(node_name)
def _get_real_node(self, node_or_name: Union[Node, str]) -> Optional[Node]:
if isinstance(node_or_name, Node):
result = self.get_node(node_or_name.get_name())
return result if result is node_or_name else None
if isinstance(node_or_name, str):
return self.get_node(node_or_name)
return None
def get_node_inputs(self, node_or_name: Union[Node, str]) -> [Node]:
"""
Getter of inputs in topological relation of current 'node_or_name'.
@ -337,6 +384,8 @@ class SymbolTree(Observer, Observable):
if real_node is None:
logger.info("Node(%s) is not belong to current SymbolTree", node_or_name)
return []
if real_node.get_node_type() == NodeType.Output:
return []
return self._topo_mgr.get_node_users(node_or_name)
def before(self, node_or_name: Union[Node, str]) -> Position:
@ -628,104 +677,6 @@ class SymbolTree(Observer, Observable):
self._node_visitor.remove_node(node)
return node
def _insert_tree(self, position: Position, root: Node, insert_to_ast: bool = True) -> Node:
"""
Insert a node-tree into SymbolTree.
Note:
Inputs of intra sub-tree nodes need to be welly set.
Inputs of inter sub-tree nodes will be updated by Rewrite automatically.
Args:
position (Position): A Position indicates an insert position point.
root (Node): An instance of node as root of node-tree to be inserted in.
insert_to_ast (bool): A bool indicates whether to update corresponding ast node at same time, default is
True.
Returns:
An instance of node as root node of node-tree which has been inserted into SymbolTree.
Raises:
RuntimeError: If 'position' is not in current SymbolTree.
"""
# if position not in current SymbolTree
if position.symbol_tree is not self:
raise RuntimeError("Position is not in current SymbolTree: ", position)
queue: [Node] = [root]
todos: [] = []
inputs_list: [] = []
while queue:
cur_node = queue.pop(0)
if cur_node in todos:
continue
todos.append(cur_node)
node_inputs = cur_node.get_inputs()
inputs_list.append(node_inputs)
for node_input in node_inputs:
if node_input is not None:
queue.append(node_input)
todos.reverse()
inputs_list.reverse()
for index, todo in enumerate(todos):
self.insert_node(position, todo, insert_to_ast)
position = self.after(todo)
# relink input of node
original_inputs = inputs_list[index]
for arg_idx, original_input in enumerate(original_inputs):
if original_input is not None:
self.set_node_arg_by_node(todo, arg_idx, original_input)
return root
@staticmethod
def _link_nodes_and_find_root(nodes: [Node]) -> Node:
"""
Find inputs for all nodes created by Replacement according to their targets and arguments.
Find root node of all nodes created by Replacement. One and Only one root should be found.
Args:
nodes (list[Node]): A list of instance of Node created by Replacement.
Returns:
An instance of Node represents root of input nodes.
"""
consumers: [ScopedValue] = []
target_dict: {ScopedValue: Node} = {}
for node in nodes:
consumers.extend(node.get_args())
for _, arg in node.get_kwargs():
consumers.append(arg)
for target in node.get_targets():
if target_dict.get(target) is not None:
raise RuntimeError("Target of node duplicated")
target_dict[target] = node
# find root node
root = None
for node in nodes:
used = 0
for target in node.get_targets():
if target in consumers:
used += 1
if used == 0:
if root is not None:
raise RuntimeError("Replacement should only has one root")
root = node
if root is None:
raise RuntimeError("No root node found in replacement nodes")
# link node's input
for node in nodes:
inputs = []
for _, arg in node.get_normalized_args().items():
node_input: Node = target_dict.get(arg)
if node_input is None:
inputs.append(None)
else:
inputs.append(node_input)
node.set_inputs(inputs)
return root
def replace(self, old_node: Node, new_nodes: [Node]) -> Node:
"""
Replace an old_node with a node_tree. 'new_node' is the root node of the node_tree.
@ -834,28 +785,6 @@ class SymbolTree(Observer, Observable):
dump_st = SymbolTreeDumper(self)
dump_st.dump()
@staticmethod
def _find_all_class_in_symboltree(stree: 'SymbolTree', seen_class: {type, str}, allow_class_name: [], replacers):
"""Find all non-duplicated class name of SymbolTree recursively."""
replacer = AstReplacer(stree._class_ast)
replacers.append(replacer)
for node in stree.nodes():
if not isinstance(node, TreeNode):
continue
sub_stree: SymbolTree = node.symbol_tree
SymbolTree._find_all_class_in_symboltree(sub_stree, seen_class, allow_class_name, replacers)
# all modified ast.ClassDef should export to code
if sub_stree._modified:
allow_class_name.append(sub_stree._class_ast.name)
continue
# all un-modified ast.ClassDef only keep one instance
seen_cls_name = seen_class.get(type(sub_stree.get_origin_network()))
if seen_cls_name is not None:
replacer.replace_all(sub_stree._class_ast.name, seen_cls_name)
else:
seen_class[type(sub_stree.get_origin_network())] = sub_stree._class_ast.name
allow_class_name.append(sub_stree._class_ast.name)
def get_code(self) -> str:
"""
Get source code of modified network.
@ -897,6 +826,64 @@ class SymbolTree(Observer, Observable):
cls = self._get_cls_through_file()
return cls(self._global_vars)
def _get_real_node(self, node_or_name: Union[Node, str]) -> Optional[Node]:
if isinstance(node_or_name, Node):
result = self.get_node(node_or_name.get_name())
return result if result is node_or_name else None
if isinstance(node_or_name, str):
return self.get_node(node_or_name)
return None
def _insert_tree(self, position: Position, root: Node, insert_to_ast: bool = True) -> Node:
"""
Insert a node-tree into SymbolTree.
Note:
Inputs of intra sub-tree nodes need to be welly set.
Inputs of inter sub-tree nodes will be updated by Rewrite automatically.
Args:
position (Position): A Position indicates an insert position point.
root (Node): An instance of node as root of node-tree to be inserted in.
insert_to_ast (bool): A bool indicates whether to update corresponding ast node at same time, default is
True.
Returns:
An instance of node as root node of node-tree which has been inserted into SymbolTree.
Raises:
RuntimeError: If 'position' is not in current SymbolTree.
"""
# if position not in current SymbolTree
if position.symbol_tree is not self:
raise RuntimeError("Position is not in current SymbolTree: ", position)
queue: [Node] = [root]
todos: [] = []
inputs_list: [] = []
while queue:
cur_node = queue.pop(0)
if cur_node in todos:
continue
todos.append(cur_node)
node_inputs = cur_node.get_inputs()
inputs_list.append(node_inputs)
for node_input in node_inputs:
if node_input is not None:
queue.append(node_input)
todos.reverse()
inputs_list.reverse()
for index, todo in enumerate(todos):
self.insert_node(position, todo, insert_to_ast)
position = self.after(todo)
# relink input of node
original_inputs = inputs_list[index]
for arg_idx, original_input in enumerate(original_inputs):
if original_input is not None:
self.set_node_arg_by_node(todo, arg_idx, original_input)
return root
def _unique_targets(self, node: Node):
"""
Unique targets of node by _target_namer.
@ -1025,3 +1012,7 @@ class SymbolTree(Observer, Observable):
if network_cls is None:
raise RuntimeError("Can not find network class:", self._opt_cls_name)
return network_cls
def _on_change(self, event: Event):
self._modified = True
self.changed(event)