!34057 modify for rewrite pattern
Merge pull request !34057 from 于振华/rewrite_pattern
This commit is contained in:
commit
c3647bc54b
|
@ -279,12 +279,9 @@ class PatternEngine:
|
|||
RuntimeError: If `SymbolTree` has no return node.
|
||||
"""
|
||||
|
||||
root: Node = stree.get_return_node()
|
||||
if root is None:
|
||||
raise RuntimeError("SymbolTree should be inited and has return node")
|
||||
changed = False
|
||||
# IR match
|
||||
queue: [Node] = [root]
|
||||
queue: [Node] = stree.get_inputs()
|
||||
# Why need visited: we don't need or should not to visit same node multi-times because pre-visited node may
|
||||
# already been erased from SymbolTree.
|
||||
# When will we visit same node multi-times:
|
||||
|
@ -314,20 +311,20 @@ class PatternEngine:
|
|||
subtree = TreeNodeHelper.get_sub_tree(cur_node)
|
||||
self.apply(subtree)
|
||||
visited.append(cur_node)
|
||||
queue.extend(cur_node.get_inputs())
|
||||
queue.extend(cur_node.get_users())
|
||||
continue
|
||||
visited.append(cur_node)
|
||||
matched, matched_dict = self._match(self._pattern, cur_node)
|
||||
# not matched
|
||||
if not matched or not PatternEngine._check_match(self._pattern, matched_dict):
|
||||
queue.extend(cur_node.get_inputs())
|
||||
queue.extend(cur_node.get_users())
|
||||
continue
|
||||
# matched
|
||||
new_nodes: [Node] = []
|
||||
if self._replacement is not None:
|
||||
new_nodes: [Node] = self._replacement(self._pattern, self._is_chain, matched_dict)
|
||||
if not new_nodes: # if replacement is empty, do nothing
|
||||
queue.extend(cur_node.get_inputs())
|
||||
queue.extend(cur_node.get_users())
|
||||
else: # replace cur_node with new_nodes
|
||||
changed = True
|
||||
root = PatternEngine._multi_to_multi_replace(stree, cur_node, matched_dict, new_nodes)
|
||||
|
|
|
@ -56,7 +56,7 @@ class SymbolTree:
|
|||
Get all nodes of corresponding network.
|
||||
|
||||
Returns:
|
||||
A dict mapping from name of node to node.
|
||||
A generator for node of current `SymbolTree`.
|
||||
"""
|
||||
for node in self._symbol_tree.nodes():
|
||||
yield Node(node)
|
||||
|
@ -76,20 +76,14 @@ class SymbolTree:
|
|||
return None
|
||||
return Node(node_impl)
|
||||
|
||||
def get_return_node(self) -> Node:
|
||||
def get_inputs(self) -> [Node]:
|
||||
"""
|
||||
Get return node of current `SymbolTree`.
|
||||
|
||||
A `SymbolTree` can and should have one return node corresponding to return statement of forward method of
|
||||
network.
|
||||
Get 'input' nodes of current `SymbolTree`.
|
||||
|
||||
Returns:
|
||||
An instance of node represents return node.
|
||||
[Node]: The node list of the current 'Symboltree'.
|
||||
"""
|
||||
ret = self._symbol_tree.get_return_node()
|
||||
if ret is None:
|
||||
raise RuntimeError("SymbolTree is not well inited, can not find return node.")
|
||||
return Node(ret)
|
||||
return [Node(node_impl) for node_impl in self._symbol_tree.get_inputs()]
|
||||
|
||||
def before(self, node: Node):
|
||||
"""
|
||||
|
|
|
@ -361,46 +361,6 @@ class Node:
|
|||
normalized_args[sorted_kwarg[0]] = sorted_kwarg[1]
|
||||
normalized_args_keys.append(sorted_kwarg[0])
|
||||
|
||||
def _get_normalized_args(self, args: [ScopedValue], kwargs: {str: ScopedValue}) -> dict:
|
||||
"""
|
||||
Merge args and kwargs to normalized args.
|
||||
The keys of args are obtained from the construct function of type(self._instance).
|
||||
|
||||
Args:
|
||||
args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
||||
kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class.
|
||||
|
||||
Raises:
|
||||
RuntimeError: Input args are invalid.
|
||||
RuntimeError: Arg name already exist in kwargs.
|
||||
|
||||
Returns:
|
||||
The normalized args.
|
||||
"""
|
||||
if not args:
|
||||
args = []
|
||||
if not kwargs:
|
||||
kwargs = {}
|
||||
normalized_args: dict = dict()
|
||||
if self._instance and hasattr(type(self._instance), "construct"):
|
||||
parameters = inspect.signature(type(self._instance).construct).parameters
|
||||
names = Node._get_construct_arg_names(parameters)
|
||||
Node._map_args_names(names, args, kwargs, self._normalized_args_keys, normalized_args)
|
||||
else:
|
||||
logger.debug("fail to get arg name from op, using arg_xx for args' name")
|
||||
arg_temp_name, suffix = "arg", 0
|
||||
for arg in args:
|
||||
arg_key = "{}_{}".format(arg_temp_name, suffix)
|
||||
while arg_key in kwargs.keys() or arg_key in normalized_args.keys():
|
||||
suffix += 1
|
||||
arg_key = "{}_{}".format(arg_temp_name, suffix)
|
||||
normalized_args[arg_key] = arg
|
||||
self._normalized_args_keys.append(arg_key)
|
||||
for arg_key, value in kwargs.items():
|
||||
normalized_args[arg_key] = value
|
||||
self._normalized_args_keys.append(arg_key)
|
||||
return normalized_args
|
||||
|
||||
@staticmethod
|
||||
def _handle_custom_obj_in_args(args: [ScopedValue]) -> [ScopedValue]:
|
||||
"""
|
||||
|
@ -468,6 +428,27 @@ class Node:
|
|||
raise RuntimeError("Invalid symbol type: ", target)
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def _get_cell_or_prim_op_attribute(obj) -> dict:
|
||||
"""
|
||||
Find attributes of cell-op or primitive-op.
|
||||
|
||||
Args:
|
||||
obj: A cell-op or a primitive-op.
|
||||
|
||||
Returns:
|
||||
A dict represents attributes of input 'obj'.
|
||||
"""
|
||||
attributes = {}
|
||||
if obj is None:
|
||||
return attributes
|
||||
for k, v in obj.__dict__.items():
|
||||
if k.startswith("_"):
|
||||
continue
|
||||
attributes[k] = v
|
||||
attributes["cls"] = obj.__class__
|
||||
return attributes
|
||||
|
||||
def get_prev(self) -> 'Node':
|
||||
"""
|
||||
Get previous node of current node in source code order.
|
||||
|
@ -477,10 +458,6 @@ class Node:
|
|||
"""
|
||||
return self._prev
|
||||
|
||||
def set_prev(self, prev):
|
||||
"""Set previous node of current node in source code order. """
|
||||
self._prev = prev
|
||||
|
||||
def get_next(self) -> 'Node':
|
||||
"""
|
||||
Get next node of current node in source code order.
|
||||
|
@ -490,10 +467,6 @@ class Node:
|
|||
"""
|
||||
return self._next
|
||||
|
||||
def set_next(self, _next):
|
||||
"""Set next node of current node in source code order."""
|
||||
self._next = _next
|
||||
|
||||
def has_same_ast(self, node: Union['Node', ast.AST]) -> bool:
|
||||
"""
|
||||
Check if other node holds same ast node with self.
|
||||
|
@ -505,7 +478,7 @@ class Node:
|
|||
A bool.
|
||||
"""
|
||||
if isinstance(node, Node):
|
||||
return self.has_same_ast(node.get_ast())
|
||||
return self.has_same_ast(node._ast_node)
|
||||
if isinstance(node, ast.AST):
|
||||
return id(self._ast_node) == id(node)
|
||||
return False
|
||||
|
@ -536,165 +509,14 @@ class Node:
|
|||
def set_belong_symbol_tree(self, symbol_tree):
|
||||
self._belong_tree = symbol_tree
|
||||
|
||||
def _sync_assign_func_to_ast(self):
|
||||
"""Sync func of ast.Call of ast.Assign from self._name when NodeType is CallCell or CallPrimitive."""
|
||||
if self._ast_node is None:
|
||||
return
|
||||
assign_ast = self._ast_node
|
||||
if not isinstance(assign_ast, ast.Assign):
|
||||
raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast))
|
||||
call_ast = assign_ast.value
|
||||
if not isinstance(call_ast, ast.Call):
|
||||
raise TypeError("call_ast should be ast.Call, got: ", type(call_ast))
|
||||
func_ast = call_ast.func
|
||||
if not self._func.value:
|
||||
if isinstance(func_ast, ast.Name):
|
||||
func_ast.id = self._func.value
|
||||
else:
|
||||
call_ast.func = ast.Name(self._func.value, ast.Store())
|
||||
else:
|
||||
if isinstance(func_ast, ast.Attribute):
|
||||
func_value = func_ast.value
|
||||
if not isinstance(func_value, ast.Name):
|
||||
raise RuntimeError("Only support ast.Name as value of attribute ", type(func_ast.value))
|
||||
func_value.id = self._func.scope
|
||||
func_ast.attr = self._func.value
|
||||
else:
|
||||
call_ast.func = ast.Attribute(ast.Name(self._func.scope, ast.Load()), self._func.value, ast.Store())
|
||||
ast.fix_missing_locations(assign_ast)
|
||||
|
||||
def _sync_assign_targets_to_ast(self):
|
||||
"""Sync targets of ast.Assign from self._targets when NodeType is CallCell, CallPrimitive or CallMethod."""
|
||||
if self._ast_node is None:
|
||||
return
|
||||
assign_ast = self._ast_node
|
||||
if not isinstance(assign_ast, ast.Assign):
|
||||
raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast))
|
||||
# update targets
|
||||
targets_ast = assign_ast.targets
|
||||
if isinstance(targets_ast[0], ast.Tuple) and len(self._targets) != len(targets_ast[0].elts):
|
||||
raise RuntimeError("self._targets should have the same length as targets_ast's elts")
|
||||
if not isinstance(targets_ast[0], ast.Tuple) and len(self._targets) != len(targets_ast):
|
||||
raise RuntimeError("self._targets should have targets_ast same length")
|
||||
for i in range(0, len(self._targets)):
|
||||
target = self._targets[i]
|
||||
target_ast = targets_ast[0]
|
||||
if isinstance(target_ast, ast.Name):
|
||||
target_ast.id = target.value
|
||||
elif isinstance(target_ast, ast.Tuple):
|
||||
if not isinstance(target_ast.elts[i], ast.Name):
|
||||
raise TypeError("target should be ast.Name, got:", type(target_ast.elts[i]))
|
||||
target_ast.elts[i].id = target.value
|
||||
else:
|
||||
raise TypeError("target_ast should be ast.Name or ast.Tuple, got: ", type(target_ast))
|
||||
target_ast.id = target.value
|
||||
ast.fix_missing_locations(assign_ast)
|
||||
|
||||
def _sync_call_cell_args_to_ast(self):
|
||||
"""Sync args of ast.Cell of ast.Assign from self._normalized_args when NodeType is CallCell or CallPrimitive."""
|
||||
if self._ast_node is None:
|
||||
return
|
||||
assign_ast = self._ast_node
|
||||
if not isinstance(assign_ast, ast.Assign):
|
||||
raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast))
|
||||
assign_value = assign_ast.value
|
||||
if not isinstance(assign_value, ast.Call):
|
||||
return
|
||||
keywords_ast = assign_value.keywords
|
||||
args_ast = assign_value.args
|
||||
if len(self._normalized_args_keys) != (len(keywords_ast) + len(args_ast)):
|
||||
raise RuntimeError("ast keywords plus args len is not equal to self._normalized_args value")
|
||||
|
||||
for arg_index in range(self._args_num):
|
||||
arg_ast = args_ast[arg_index]
|
||||
AstModifier.update_arg_value(self._normalized_args.get(self._normalized_args_keys[arg_index]), arg_ast)
|
||||
|
||||
# the order of kwargs may not the same as that in keywords_ast
|
||||
keyword_map_index = {}
|
||||
for index, keyword_ast in enumerate(keywords_ast):
|
||||
keyword_map_index[keyword_ast.arg] = index
|
||||
for keyword_index in range(self._kwargs_num):
|
||||
key = self._normalized_args_keys[keyword_index + self._args_num]
|
||||
AstModifier.update_arg_value(self._normalized_args.get(key),
|
||||
keywords_ast[keyword_map_index.get(key)].value)
|
||||
|
||||
def _sync_call_method_args_to_ast(self):
|
||||
"""Sync args of ast.Cell of ast.Assign from self._normalized_args when NodeType is CallMethod."""
|
||||
if self._ast_node is None:
|
||||
return
|
||||
assign_ast = self._ast_node
|
||||
if not isinstance(assign_ast, ast.Assign):
|
||||
raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast))
|
||||
assign_value = assign_ast.value
|
||||
if self._func == PASS_THROUGH_METHOD:
|
||||
if isinstance(assign_value, ast.Name):
|
||||
if len(self._normalized_args_keys) != 1:
|
||||
raise RuntimeError("self._normalized_args_keys should have 1 elements")
|
||||
arg = self._normalized_args.get(self._normalized_args_keys[0])
|
||||
if arg.type != ValueType.NamingValue:
|
||||
raise RuntimeError("arg.type should equal to ValueType.NamingValue")
|
||||
if arg.scope != "":
|
||||
raise RuntimeError("arg.scope should be empty")
|
||||
assign_value.id = arg.value
|
||||
elif isinstance(assign_value, ast.Attribute):
|
||||
if len(self._normalized_args_keys) != 1:
|
||||
raise RuntimeError("self._normalized_args_keys should have 1 elements")
|
||||
arg = self._normalized_args.get(self._normalized_args_keys[0])
|
||||
if arg.type != ValueType.NamingValue:
|
||||
raise RuntimeError("arg.type should equal to ValueType.NamingValue")
|
||||
assign_value.attr = arg.value
|
||||
assign_value_value = assign_value.value
|
||||
if not isinstance(assign_value_value, ast.Name):
|
||||
raise RuntimeError("Only support ast.Name as value of attribute ", type(assign_value_value))
|
||||
assign_value_value.id = arg.scope
|
||||
else:
|
||||
if len(self._normalized_args_keys) != 1:
|
||||
raise RuntimeError("self._normalized_args_keys should have 1 elements")
|
||||
arg = self._normalized_args.get(self._normalized_args_keys[0])
|
||||
if arg.type not in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
|
||||
raise RuntimeError("arg should be an IntValue, FloatValue or StringValue")
|
||||
if arg.scope != "":
|
||||
raise RuntimeError("arg.scope should be empty")
|
||||
assign_value.value = arg.value
|
||||
else:
|
||||
raise RuntimeError("Only support pass_through method as call_method now, ", self._func.value)
|
||||
|
||||
def _sync_return_node_to_ast(self):
|
||||
"""Sync return value of ast.Return from self._normalized_args when NodeType is Output."""
|
||||
if self._ast_node is None:
|
||||
return
|
||||
return_ast = self._ast_node
|
||||
if not isinstance(return_ast, ast.Return):
|
||||
raise TypeError("return_ast should be ast.Return, got: ", type(return_ast))
|
||||
# update args
|
||||
return_value_ast = return_ast.value
|
||||
if isinstance(return_value_ast, ast.Name):
|
||||
if len(self._normalized_args_keys) != 1:
|
||||
raise RuntimeError("self._normalized_args_keys should have 1 elements")
|
||||
return_value_ast.id = self._normalized_args.get(self._normalized_args_keys[0]).value
|
||||
elif isinstance(return_value_ast, ast.Tuple):
|
||||
elements = return_value_ast.elts
|
||||
if len(self._normalized_args.values()) != len(elements):
|
||||
raise RuntimeError("self._normalized_args.values() should have elements same length")
|
||||
for elt_index, elt in enumerate(elements):
|
||||
if not isinstance(elt, ast.Name):
|
||||
raise RuntimeError("Only support ast.Name as return value: ", elt)
|
||||
arg = self._normalized_args.get(self._normalized_args_keys[elt_index])
|
||||
if not isinstance(arg, ScopedValue):
|
||||
raise TypeError("arg should be ScopedValue, got: ", type(arg))
|
||||
elt.id = arg.value
|
||||
else:
|
||||
raise RuntimeError("Unsupported return value type: ", return_value_ast)
|
||||
ast.fix_missing_locations(return_ast)
|
||||
|
||||
def isolate(self):
|
||||
"""Link prev node to next node and isolate node from source code order list."""
|
||||
origin_prev: Optional[Node] = self._prev
|
||||
origin_next: Optional[Node] = self._next
|
||||
if origin_prev is not None:
|
||||
origin_prev.set_next(origin_next)
|
||||
origin_prev._next = origin_next
|
||||
if origin_next is not None:
|
||||
origin_next.set_prev(origin_prev)
|
||||
origin_next._prev = origin_prev
|
||||
self._prev = None
|
||||
self._next = None
|
||||
|
||||
|
@ -708,9 +530,9 @@ class Node:
|
|||
node.isolate()
|
||||
origin_prev: Optional[Node] = self._prev
|
||||
if origin_prev is not None:
|
||||
origin_prev.set_next(node)
|
||||
node.set_prev(origin_prev)
|
||||
node.set_next(self)
|
||||
origin_prev._next = node
|
||||
node._prev = origin_prev
|
||||
node._next = self
|
||||
self._prev = node
|
||||
|
||||
def insert_after(self, node: 'Node'):
|
||||
|
@ -723,10 +545,10 @@ class Node:
|
|||
node.isolate()
|
||||
origin_next: Optional[Node] = self._next
|
||||
self._next = node
|
||||
node.set_prev(self)
|
||||
node.set_next(origin_next)
|
||||
node._prev = self
|
||||
node._next = origin_next
|
||||
if origin_next is not None:
|
||||
origin_next.set_prev(node)
|
||||
origin_next._prev = node
|
||||
|
||||
def get_inputs(self) -> ['Node']:
|
||||
"""
|
||||
|
@ -867,15 +689,6 @@ class Node:
|
|||
"""
|
||||
return self._instance
|
||||
|
||||
def _sync_arg(self):
|
||||
"""Sync _normalized_args to corresponding ast node when updated."""
|
||||
if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree):
|
||||
self._sync_call_cell_args_to_ast()
|
||||
elif self._node_type == NodeType.Output:
|
||||
self._sync_return_node_to_ast()
|
||||
elif self._node_type == NodeType.CallMethod:
|
||||
self._sync_call_method_args_to_ast()
|
||||
|
||||
def set_arg_by_node(self, arg_idx: int, node: 'Node', out_idx: Optional[int] = None):
|
||||
"""
|
||||
Set argument by another Node.
|
||||
|
@ -896,12 +709,12 @@ class Node:
|
|||
if arg_idx >= self._args_num or arg_idx < 0:
|
||||
raise RuntimeError("arg_idx out of range: ", arg_idx)
|
||||
if out_idx is None:
|
||||
if len(node.get_targets()) != 1:
|
||||
if len(node._targets) != 1:
|
||||
raise RuntimeError("node should has one output when out_idx is not provided")
|
||||
out_idx = 0
|
||||
if out_idx >= len(node.get_targets()):
|
||||
if out_idx >= len(node._targets):
|
||||
raise RuntimeError("out_idx out of range: ", out_idx)
|
||||
new_arg = node.get_targets()[out_idx]
|
||||
new_arg = node._targets[out_idx]
|
||||
self._normalized_args[self._normalized_args_keys[arg_idx]] = new_arg
|
||||
self._sync_arg()
|
||||
|
||||
|
@ -1091,26 +904,205 @@ class Node:
|
|||
"""
|
||||
return self._attribute.get(key)
|
||||
|
||||
@staticmethod
|
||||
def _get_cell_or_prim_op_attribute(obj) -> dict:
|
||||
def _get_normalized_args(self, args: [ScopedValue], kwargs: {str: ScopedValue}) -> dict:
|
||||
"""
|
||||
Find attributes of cell-op or primitive-op.
|
||||
Merge args and kwargs to normalized args.
|
||||
The keys of args are obtained from the construct function of type(self._instance).
|
||||
|
||||
Args:
|
||||
obj: A cell-op or a primitive-op.
|
||||
args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
||||
kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class.
|
||||
|
||||
Raises:
|
||||
RuntimeError: Input args are invalid.
|
||||
RuntimeError: Arg name already exist in kwargs.
|
||||
|
||||
Returns:
|
||||
A dict represents attributes of input 'obj'.
|
||||
The normalized args.
|
||||
"""
|
||||
attributes = {}
|
||||
if obj is None:
|
||||
return attributes
|
||||
for k, v in obj.__dict__.items():
|
||||
if k.startswith("_"):
|
||||
continue
|
||||
attributes[k] = v
|
||||
attributes["cls"] = obj.__class__
|
||||
return attributes
|
||||
if not args:
|
||||
args = []
|
||||
if not kwargs:
|
||||
kwargs = {}
|
||||
normalized_args: dict = dict()
|
||||
if self._instance and hasattr(type(self._instance), "construct"):
|
||||
parameters = inspect.signature(type(self._instance).construct).parameters
|
||||
names = Node._get_construct_arg_names(parameters)
|
||||
Node._map_args_names(names, args, kwargs, self._normalized_args_keys, normalized_args)
|
||||
else:
|
||||
logger.debug("fail to get arg name from op, using arg_xx for args' name")
|
||||
arg_temp_name, suffix = "arg", 0
|
||||
for arg in args:
|
||||
arg_key = "{}_{}".format(arg_temp_name, suffix)
|
||||
while arg_key in kwargs.keys() or arg_key in normalized_args.keys():
|
||||
suffix += 1
|
||||
arg_key = "{}_{}".format(arg_temp_name, suffix)
|
||||
normalized_args[arg_key] = arg
|
||||
self._normalized_args_keys.append(arg_key)
|
||||
for arg_key, value in kwargs.items():
|
||||
normalized_args[arg_key] = value
|
||||
self._normalized_args_keys.append(arg_key)
|
||||
return normalized_args
|
||||
|
||||
def _sync_assign_func_to_ast(self):
|
||||
"""Sync func of ast.Call of ast.Assign from self._name when NodeType is CallCell or CallPrimitive."""
|
||||
if self._ast_node is None:
|
||||
return
|
||||
assign_ast = self._ast_node
|
||||
if not isinstance(assign_ast, ast.Assign):
|
||||
raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast))
|
||||
call_ast = assign_ast.value
|
||||
if not isinstance(call_ast, ast.Call):
|
||||
raise TypeError("call_ast should be ast.Call, got: ", type(call_ast))
|
||||
func_ast = call_ast.func
|
||||
if not self._func.value:
|
||||
if isinstance(func_ast, ast.Name):
|
||||
func_ast.id = self._func.value
|
||||
else:
|
||||
call_ast.func = ast.Name(self._func.value, ast.Store())
|
||||
else:
|
||||
if isinstance(func_ast, ast.Attribute):
|
||||
func_value = func_ast.value
|
||||
if not isinstance(func_value, ast.Name):
|
||||
raise RuntimeError("Only support ast.Name as value of attribute ", type(func_ast.value))
|
||||
func_value.id = self._func.scope
|
||||
func_ast.attr = self._func.value
|
||||
else:
|
||||
call_ast.func = ast.Attribute(ast.Name(self._func.scope, ast.Load()), self._func.value, ast.Store())
|
||||
ast.fix_missing_locations(assign_ast)
|
||||
|
||||
def _sync_assign_targets_to_ast(self):
|
||||
"""Sync targets of ast.Assign from self._targets when NodeType is CallCell, CallPrimitive or CallMethod."""
|
||||
if self._ast_node is None:
|
||||
return
|
||||
assign_ast = self._ast_node
|
||||
if not isinstance(assign_ast, ast.Assign):
|
||||
raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast))
|
||||
# update targets
|
||||
targets_ast = assign_ast.targets
|
||||
if isinstance(targets_ast[0], ast.Tuple) and len(self._targets) != len(targets_ast[0].elts):
|
||||
raise RuntimeError("self._targets should have the same length as targets_ast's elts")
|
||||
if not isinstance(targets_ast[0], ast.Tuple) and len(self._targets) != len(targets_ast):
|
||||
raise RuntimeError("self._targets should have targets_ast same length")
|
||||
for i in range(0, len(self._targets)):
|
||||
target = self._targets[i]
|
||||
target_ast = targets_ast[0]
|
||||
if isinstance(target_ast, ast.Name):
|
||||
target_ast.id = target.value
|
||||
elif isinstance(target_ast, ast.Tuple):
|
||||
if not isinstance(target_ast.elts[i], ast.Name):
|
||||
raise TypeError("target should be ast.Name, got:", type(target_ast.elts[i]))
|
||||
target_ast.elts[i].id = target.value
|
||||
else:
|
||||
raise TypeError("target_ast should be ast.Name or ast.Tuple, got: ", type(target_ast))
|
||||
target_ast.id = target.value
|
||||
ast.fix_missing_locations(assign_ast)
|
||||
|
||||
def _sync_call_cell_args_to_ast(self):
|
||||
"""Sync args of ast.Cell of ast.Assign from self._normalized_args when NodeType is CallCell or CallPrimitive."""
|
||||
if self._ast_node is None:
|
||||
return
|
||||
assign_ast = self._ast_node
|
||||
if not isinstance(assign_ast, ast.Assign):
|
||||
raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast))
|
||||
assign_value = assign_ast.value
|
||||
if not isinstance(assign_value, ast.Call):
|
||||
return
|
||||
keywords_ast = assign_value.keywords
|
||||
args_ast = assign_value.args
|
||||
if len(self._normalized_args_keys) != (len(keywords_ast) + len(args_ast)):
|
||||
raise RuntimeError("ast keywords plus args len is not equal to self._normalized_args value")
|
||||
|
||||
for arg_index in range(self._args_num):
|
||||
arg_ast = args_ast[arg_index]
|
||||
AstModifier.update_arg_value(self._normalized_args.get(self._normalized_args_keys[arg_index]), arg_ast)
|
||||
|
||||
# the order of kwargs may not the same as that in keywords_ast
|
||||
keyword_map_index = {}
|
||||
for index, keyword_ast in enumerate(keywords_ast):
|
||||
keyword_map_index[keyword_ast.arg] = index
|
||||
for keyword_index in range(self._kwargs_num):
|
||||
key = self._normalized_args_keys[keyword_index + self._args_num]
|
||||
AstModifier.update_arg_value(self._normalized_args.get(key),
|
||||
keywords_ast[keyword_map_index.get(key)].value)
|
||||
|
||||
def _sync_call_method_args_to_ast(self):
|
||||
"""Sync args of ast.Cell of ast.Assign from self._normalized_args when NodeType is CallMethod."""
|
||||
if self._ast_node is None:
|
||||
return
|
||||
assign_ast = self._ast_node
|
||||
if not isinstance(assign_ast, ast.Assign):
|
||||
raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast))
|
||||
assign_value = assign_ast.value
|
||||
if self._func == PASS_THROUGH_METHOD:
|
||||
if isinstance(assign_value, ast.Name):
|
||||
if len(self._normalized_args_keys) != 1:
|
||||
raise RuntimeError("self._normalized_args_keys should have 1 elements")
|
||||
arg = self._normalized_args.get(self._normalized_args_keys[0])
|
||||
if arg.type != ValueType.NamingValue:
|
||||
raise RuntimeError("arg.type should equal to ValueType.NamingValue")
|
||||
if arg.scope != "":
|
||||
raise RuntimeError("arg.scope should be empty")
|
||||
assign_value.id = arg.value
|
||||
elif isinstance(assign_value, ast.Attribute):
|
||||
if len(self._normalized_args_keys) != 1:
|
||||
raise RuntimeError("self._normalized_args_keys should have 1 elements")
|
||||
arg = self._normalized_args.get(self._normalized_args_keys[0])
|
||||
if arg.type != ValueType.NamingValue:
|
||||
raise RuntimeError("arg.type should equal to ValueType.NamingValue")
|
||||
assign_value.attr = arg.value
|
||||
assign_value_value = assign_value.value
|
||||
if not isinstance(assign_value_value, ast.Name):
|
||||
raise RuntimeError("Only support ast.Name as value of attribute ", type(assign_value_value))
|
||||
assign_value_value.id = arg.scope
|
||||
else:
|
||||
if len(self._normalized_args_keys) != 1:
|
||||
raise RuntimeError("self._normalized_args_keys should have 1 elements")
|
||||
arg = self._normalized_args.get(self._normalized_args_keys[0])
|
||||
if arg.type not in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
|
||||
raise RuntimeError("arg should be an IntValue, FloatValue or StringValue")
|
||||
if arg.scope != "":
|
||||
raise RuntimeError("arg.scope should be empty")
|
||||
assign_value.value = arg.value
|
||||
else:
|
||||
raise RuntimeError("Only support pass_through method as call_method now, ", self._func.value)
|
||||
|
||||
def _sync_return_node_to_ast(self):
|
||||
"""Sync return value of ast.Return from self._normalized_args when NodeType is Output."""
|
||||
if self._ast_node is None:
|
||||
return
|
||||
return_ast = self._ast_node
|
||||
if not isinstance(return_ast, ast.Return):
|
||||
raise TypeError("return_ast should be ast.Return, got: ", type(return_ast))
|
||||
# update args
|
||||
return_value_ast = return_ast.value
|
||||
if isinstance(return_value_ast, ast.Name):
|
||||
if len(self._normalized_args_keys) != 1:
|
||||
raise RuntimeError("self._normalized_args_keys should have 1 elements")
|
||||
return_value_ast.id = self._normalized_args.get(self._normalized_args_keys[0]).value
|
||||
elif isinstance(return_value_ast, ast.Tuple):
|
||||
elements = return_value_ast.elts
|
||||
if len(self._normalized_args.values()) != len(elements):
|
||||
raise RuntimeError("self._normalized_args.values() should have elements same length")
|
||||
for elt_index, elt in enumerate(elements):
|
||||
if not isinstance(elt, ast.Name):
|
||||
raise RuntimeError("Only support ast.Name as return value: ", elt)
|
||||
arg = self._normalized_args.get(self._normalized_args_keys[elt_index])
|
||||
if not isinstance(arg, ScopedValue):
|
||||
raise TypeError("arg should be ScopedValue, got: ", type(arg))
|
||||
elt.id = arg.value
|
||||
else:
|
||||
raise RuntimeError("Unsupported return value type: ", return_value_ast)
|
||||
ast.fix_missing_locations(return_ast)
|
||||
|
||||
def _sync_arg(self):
|
||||
"""Sync _normalized_args to corresponding ast node when updated."""
|
||||
if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree):
|
||||
self._sync_call_cell_args_to_ast()
|
||||
elif self._node_type == NodeType.Output:
|
||||
self._sync_return_node_to_ast()
|
||||
elif self._node_type == NodeType.CallMethod:
|
||||
self._sync_call_method_args_to_ast()
|
||||
|
||||
|
||||
class TreeNode(Node):
|
||||
|
|
|
@ -116,13 +116,77 @@ class SymbolTree(Observer, Observable):
|
|||
self._modified = False
|
||||
self._node_visitor = None
|
||||
|
||||
@staticmethod
|
||||
def _link_nodes_and_find_root(nodes: [Node]) -> Node:
|
||||
"""
|
||||
Find inputs for all nodes created by Replacement according to their targets and arguments.
|
||||
|
||||
Find root node of all nodes created by Replacement. One and Only one root should be found.
|
||||
|
||||
Args:
|
||||
nodes (list[Node]): A list of instance of Node created by Replacement.
|
||||
|
||||
Returns:
|
||||
An instance of Node represents root of input nodes.
|
||||
"""
|
||||
consumers: [ScopedValue] = []
|
||||
target_dict: {ScopedValue: Node} = {}
|
||||
for node in nodes:
|
||||
consumers.extend(node.get_args())
|
||||
for _, arg in node.get_kwargs():
|
||||
consumers.append(arg)
|
||||
for target in node.get_targets():
|
||||
if target_dict.get(target) is not None:
|
||||
raise RuntimeError("Target of node duplicated")
|
||||
target_dict[target] = node
|
||||
# find root node
|
||||
root = None
|
||||
for node in nodes:
|
||||
used = 0
|
||||
for target in node.get_targets():
|
||||
if target in consumers:
|
||||
used += 1
|
||||
break
|
||||
if used == 0:
|
||||
if root is not None:
|
||||
raise RuntimeError("Replacement should only has one root")
|
||||
root = node
|
||||
if root is None:
|
||||
raise RuntimeError("No root node found in replacement nodes")
|
||||
# link node's input
|
||||
for node in nodes:
|
||||
inputs = []
|
||||
for _, arg in node.get_normalized_args().items():
|
||||
node_input: Node = target_dict.get(arg)
|
||||
inputs.append(node_input)
|
||||
node.set_inputs(inputs)
|
||||
return root
|
||||
|
||||
@staticmethod
|
||||
def _find_all_class_in_symboltree(stree: 'SymbolTree', seen_class: {type, str}, allow_class_name: [], replacers):
|
||||
"""Find all non-duplicated class name of SymbolTree recursively."""
|
||||
replacer = AstReplacer(stree._class_ast)
|
||||
replacers.append(replacer)
|
||||
for node in stree.nodes():
|
||||
if not isinstance(node, TreeNode):
|
||||
continue
|
||||
sub_stree: SymbolTree = node.symbol_tree
|
||||
SymbolTree._find_all_class_in_symboltree(sub_stree, seen_class, allow_class_name, replacers)
|
||||
# all modified ast.ClassDef should export to code
|
||||
if sub_stree._modified:
|
||||
allow_class_name.append(sub_stree._class_ast.name)
|
||||
continue
|
||||
# all un-modified ast.ClassDef only keep one instance
|
||||
seen_cls_name = seen_class.get(type(sub_stree.get_origin_network()))
|
||||
if seen_cls_name is not None:
|
||||
replacer.replace_all(sub_stree._class_ast.name, seen_cls_name)
|
||||
else:
|
||||
seen_class[type(sub_stree.get_origin_network())] = sub_stree._class_ast.name
|
||||
allow_class_name.append(sub_stree._class_ast.name)
|
||||
|
||||
def finish_build(self):
|
||||
self.add_event(Event.TopologicalChangeEvent)
|
||||
|
||||
def _on_change(self, event: Event):
|
||||
self._modified = True
|
||||
self.changed(event)
|
||||
|
||||
def get_ori_cls_name(self) -> str:
|
||||
"""
|
||||
Get class name of original network.
|
||||
|
@ -234,15 +298,6 @@ class SymbolTree(Observer, Observable):
|
|||
"""
|
||||
return self._head
|
||||
|
||||
def get_return_node(self):
|
||||
"""
|
||||
Getter of `_return` which represents return statement of forward method of network.
|
||||
|
||||
Returns:
|
||||
An instance of node.
|
||||
"""
|
||||
return self._return
|
||||
|
||||
def get_origin_network(self):
|
||||
"""
|
||||
Getter of `_origin_network`.
|
||||
|
@ -295,14 +350,6 @@ class SymbolTree(Observer, Observable):
|
|||
|
||||
return self._nodes.get(node_name)
|
||||
|
||||
def _get_real_node(self, node_or_name: Union[Node, str]) -> Optional[Node]:
|
||||
if isinstance(node_or_name, Node):
|
||||
result = self.get_node(node_or_name.get_name())
|
||||
return result if result is node_or_name else None
|
||||
if isinstance(node_or_name, str):
|
||||
return self.get_node(node_or_name)
|
||||
return None
|
||||
|
||||
def get_node_inputs(self, node_or_name: Union[Node, str]) -> [Node]:
|
||||
"""
|
||||
Getter of inputs in topological relation of current 'node_or_name'.
|
||||
|
@ -337,6 +384,8 @@ class SymbolTree(Observer, Observable):
|
|||
if real_node is None:
|
||||
logger.info("Node(%s) is not belong to current SymbolTree", node_or_name)
|
||||
return []
|
||||
if real_node.get_node_type() == NodeType.Output:
|
||||
return []
|
||||
return self._topo_mgr.get_node_users(node_or_name)
|
||||
|
||||
def before(self, node_or_name: Union[Node, str]) -> Position:
|
||||
|
@ -628,104 +677,6 @@ class SymbolTree(Observer, Observable):
|
|||
self._node_visitor.remove_node(node)
|
||||
return node
|
||||
|
||||
def _insert_tree(self, position: Position, root: Node, insert_to_ast: bool = True) -> Node:
|
||||
"""
|
||||
Insert a node-tree into SymbolTree.
|
||||
Note:
|
||||
Inputs of intra sub-tree nodes need to be welly set.
|
||||
|
||||
Inputs of inter sub-tree nodes will be updated by Rewrite automatically.
|
||||
|
||||
Args:
|
||||
position (Position): A Position indicates an insert position point.
|
||||
root (Node): An instance of node as root of node-tree to be inserted in.
|
||||
insert_to_ast (bool): A bool indicates whether to update corresponding ast node at same time, default is
|
||||
True.
|
||||
|
||||
Returns:
|
||||
An instance of node as root node of node-tree which has been inserted into SymbolTree.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If 'position' is not in current SymbolTree.
|
||||
"""
|
||||
|
||||
# if position not in current SymbolTree
|
||||
if position.symbol_tree is not self:
|
||||
raise RuntimeError("Position is not in current SymbolTree: ", position)
|
||||
|
||||
queue: [Node] = [root]
|
||||
todos: [] = []
|
||||
inputs_list: [] = []
|
||||
while queue:
|
||||
cur_node = queue.pop(0)
|
||||
if cur_node in todos:
|
||||
continue
|
||||
todos.append(cur_node)
|
||||
node_inputs = cur_node.get_inputs()
|
||||
inputs_list.append(node_inputs)
|
||||
for node_input in node_inputs:
|
||||
if node_input is not None:
|
||||
queue.append(node_input)
|
||||
todos.reverse()
|
||||
inputs_list.reverse()
|
||||
for index, todo in enumerate(todos):
|
||||
self.insert_node(position, todo, insert_to_ast)
|
||||
position = self.after(todo)
|
||||
# relink input of node
|
||||
original_inputs = inputs_list[index]
|
||||
for arg_idx, original_input in enumerate(original_inputs):
|
||||
if original_input is not None:
|
||||
self.set_node_arg_by_node(todo, arg_idx, original_input)
|
||||
return root
|
||||
|
||||
@staticmethod
|
||||
def _link_nodes_and_find_root(nodes: [Node]) -> Node:
|
||||
"""
|
||||
Find inputs for all nodes created by Replacement according to their targets and arguments.
|
||||
|
||||
Find root node of all nodes created by Replacement. One and Only one root should be found.
|
||||
|
||||
Args:
|
||||
nodes (list[Node]): A list of instance of Node created by Replacement.
|
||||
|
||||
Returns:
|
||||
An instance of Node represents root of input nodes.
|
||||
"""
|
||||
consumers: [ScopedValue] = []
|
||||
target_dict: {ScopedValue: Node} = {}
|
||||
for node in nodes:
|
||||
consumers.extend(node.get_args())
|
||||
for _, arg in node.get_kwargs():
|
||||
consumers.append(arg)
|
||||
for target in node.get_targets():
|
||||
if target_dict.get(target) is not None:
|
||||
raise RuntimeError("Target of node duplicated")
|
||||
target_dict[target] = node
|
||||
# find root node
|
||||
root = None
|
||||
for node in nodes:
|
||||
used = 0
|
||||
for target in node.get_targets():
|
||||
if target in consumers:
|
||||
used += 1
|
||||
if used == 0:
|
||||
if root is not None:
|
||||
raise RuntimeError("Replacement should only has one root")
|
||||
root = node
|
||||
if root is None:
|
||||
raise RuntimeError("No root node found in replacement nodes")
|
||||
# link node's input
|
||||
for node in nodes:
|
||||
inputs = []
|
||||
for _, arg in node.get_normalized_args().items():
|
||||
node_input: Node = target_dict.get(arg)
|
||||
if node_input is None:
|
||||
inputs.append(None)
|
||||
else:
|
||||
inputs.append(node_input)
|
||||
node.set_inputs(inputs)
|
||||
return root
|
||||
|
||||
def replace(self, old_node: Node, new_nodes: [Node]) -> Node:
|
||||
"""
|
||||
Replace an old_node with a node_tree. 'new_node' is the root node of the node_tree.
|
||||
|
@ -834,28 +785,6 @@ class SymbolTree(Observer, Observable):
|
|||
dump_st = SymbolTreeDumper(self)
|
||||
dump_st.dump()
|
||||
|
||||
@staticmethod
|
||||
def _find_all_class_in_symboltree(stree: 'SymbolTree', seen_class: {type, str}, allow_class_name: [], replacers):
|
||||
"""Find all non-duplicated class name of SymbolTree recursively."""
|
||||
replacer = AstReplacer(stree._class_ast)
|
||||
replacers.append(replacer)
|
||||
for node in stree.nodes():
|
||||
if not isinstance(node, TreeNode):
|
||||
continue
|
||||
sub_stree: SymbolTree = node.symbol_tree
|
||||
SymbolTree._find_all_class_in_symboltree(sub_stree, seen_class, allow_class_name, replacers)
|
||||
# all modified ast.ClassDef should export to code
|
||||
if sub_stree._modified:
|
||||
allow_class_name.append(sub_stree._class_ast.name)
|
||||
continue
|
||||
# all un-modified ast.ClassDef only keep one instance
|
||||
seen_cls_name = seen_class.get(type(sub_stree.get_origin_network()))
|
||||
if seen_cls_name is not None:
|
||||
replacer.replace_all(sub_stree._class_ast.name, seen_cls_name)
|
||||
else:
|
||||
seen_class[type(sub_stree.get_origin_network())] = sub_stree._class_ast.name
|
||||
allow_class_name.append(sub_stree._class_ast.name)
|
||||
|
||||
def get_code(self) -> str:
|
||||
"""
|
||||
Get source code of modified network.
|
||||
|
@ -897,6 +826,64 @@ class SymbolTree(Observer, Observable):
|
|||
cls = self._get_cls_through_file()
|
||||
return cls(self._global_vars)
|
||||
|
||||
def _get_real_node(self, node_or_name: Union[Node, str]) -> Optional[Node]:
|
||||
if isinstance(node_or_name, Node):
|
||||
result = self.get_node(node_or_name.get_name())
|
||||
return result if result is node_or_name else None
|
||||
if isinstance(node_or_name, str):
|
||||
return self.get_node(node_or_name)
|
||||
return None
|
||||
|
||||
def _insert_tree(self, position: Position, root: Node, insert_to_ast: bool = True) -> Node:
|
||||
"""
|
||||
Insert a node-tree into SymbolTree.
|
||||
Note:
|
||||
Inputs of intra sub-tree nodes need to be welly set.
|
||||
|
||||
Inputs of inter sub-tree nodes will be updated by Rewrite automatically.
|
||||
|
||||
Args:
|
||||
position (Position): A Position indicates an insert position point.
|
||||
root (Node): An instance of node as root of node-tree to be inserted in.
|
||||
insert_to_ast (bool): A bool indicates whether to update corresponding ast node at same time, default is
|
||||
True.
|
||||
|
||||
Returns:
|
||||
An instance of node as root node of node-tree which has been inserted into SymbolTree.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If 'position' is not in current SymbolTree.
|
||||
"""
|
||||
|
||||
# if position not in current SymbolTree
|
||||
if position.symbol_tree is not self:
|
||||
raise RuntimeError("Position is not in current SymbolTree: ", position)
|
||||
|
||||
queue: [Node] = [root]
|
||||
todos: [] = []
|
||||
inputs_list: [] = []
|
||||
while queue:
|
||||
cur_node = queue.pop(0)
|
||||
if cur_node in todos:
|
||||
continue
|
||||
todos.append(cur_node)
|
||||
node_inputs = cur_node.get_inputs()
|
||||
inputs_list.append(node_inputs)
|
||||
for node_input in node_inputs:
|
||||
if node_input is not None:
|
||||
queue.append(node_input)
|
||||
todos.reverse()
|
||||
inputs_list.reverse()
|
||||
for index, todo in enumerate(todos):
|
||||
self.insert_node(position, todo, insert_to_ast)
|
||||
position = self.after(todo)
|
||||
# relink input of node
|
||||
original_inputs = inputs_list[index]
|
||||
for arg_idx, original_input in enumerate(original_inputs):
|
||||
if original_input is not None:
|
||||
self.set_node_arg_by_node(todo, arg_idx, original_input)
|
||||
return root
|
||||
|
||||
def _unique_targets(self, node: Node):
|
||||
"""
|
||||
Unique targets of node by _target_namer.
|
||||
|
@ -1025,3 +1012,7 @@ class SymbolTree(Observer, Observable):
|
|||
if network_cls is None:
|
||||
raise RuntimeError("Can not find network class:", self._opt_cls_name)
|
||||
return network_cls
|
||||
|
||||
def _on_change(self, event: Event):
|
||||
self._modified = True
|
||||
self.changed(event)
|
||||
|
|
Loading…
Reference in New Issue