!33747 rewrite subgraph
Merge pull request !33747 from 于振华/rewrite_subgraph01
This commit is contained in:
commit
ad304d342e
|
@ -88,8 +88,8 @@ class Node:
|
|||
RuntimeError: If value of kwarg in `kwargs` is not a `NamingValue`-`ScopedValue` or a
|
||||
`CustomObjValue`-`ScopedValue`.
|
||||
"""
|
||||
return Node(NodeImpl.create_call_buildin_op(cell, None, targets, ScopedValue.create_naming_value(name, "self"),
|
||||
args, kwargs, name))
|
||||
return Node(NodeImpl.create_call_op(cell, None, targets, ScopedValue.create_naming_value(name, "self"),
|
||||
args, kwargs, name))
|
||||
|
||||
def get_prev(self) -> 'Node':
|
||||
"""
|
||||
|
|
|
@ -51,14 +51,15 @@ class SymbolTree:
|
|||
"""
|
||||
return self._symbol_tree
|
||||
|
||||
def nodes(self) -> {}:
|
||||
def nodes(self):
|
||||
"""
|
||||
Get all nodes of corresponding network.
|
||||
|
||||
Returns:
|
||||
A dict mapping from name of node to node.
|
||||
"""
|
||||
return [Node(node_impl) for node_impl in self._symbol_tree.nodes(unfold_subtree=False)]
|
||||
for node in self._symbol_tree.nodes():
|
||||
yield Node(node)
|
||||
|
||||
def get_node(self, node_name: str) -> Optional[Node]:
|
||||
"""
|
||||
|
|
|
@ -19,7 +19,6 @@ from .symbol_tree import SymbolTree
|
|||
from .node import Node
|
||||
from .node_type import NodeType
|
||||
from ..symbol_tree import SymbolTree as SymbolTreeImpl
|
||||
from ..node import TreeNode
|
||||
|
||||
|
||||
class TreeNodeHelper:
|
||||
|
@ -47,7 +46,6 @@ class TreeNodeHelper:
|
|||
|
||||
if node.get_node_type() == NodeType.Tree:
|
||||
node_impl = node.get_handler()
|
||||
assert isinstance(node_impl, TreeNode)
|
||||
subtree: SymbolTreeImpl = node_impl.symbol_tree
|
||||
if subtree is None:
|
||||
return None
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Define the namespace of MindSpore op definition."""
|
||||
from .._extends.parse.namespace import CellNamespace
|
||||
|
||||
|
||||
_ms_common_ns = CellNamespace('mindspore.common')
|
||||
_ms_nn_ns = CellNamespace('mindspore.nn')
|
||||
_ms_ops_ns = CellNamespace('mindspore.ops')
|
||||
|
||||
|
||||
def is_subtree(cls_name):
|
||||
"""Determine whether 'cls_name' is a subtree."""
|
||||
if cls_name == "SequentialCell":
|
||||
return True
|
||||
if cls_name in _ms_common_ns or cls_name in _ms_nn_ns or cls_name in _ms_ops_ns:
|
||||
return False
|
||||
|
||||
return True
|
|
@ -23,6 +23,8 @@ from mindspore import log as logger
|
|||
from .ast_helpers import AstModifier
|
||||
from .api.scoped_value import ScopedValue, ValueType
|
||||
from .api.node_type import NodeType
|
||||
from .namespace import is_subtree
|
||||
from .ast_helpers.ast_replacer import AstReplacer
|
||||
|
||||
PASS_THROUGH_METHOD = ScopedValue.create_naming_value("PassThrough")
|
||||
|
||||
|
@ -223,6 +225,41 @@ class Node:
|
|||
return cls(NodeType.Output, ast_node, None, ScopedValue.create_naming_value("return"), real_return_values, {},
|
||||
name, None)
|
||||
|
||||
@staticmethod
|
||||
def create_call_op(op: Union[Cell, Primitive], ast_node: Optional[ast.AST],
|
||||
targets: [Union[ScopedValue, str]], func: Union[ScopedValue, str],
|
||||
args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None, name: str = ""):
|
||||
"""
|
||||
Static method of Node. Instantiate an instance of node whose type is `CallCell` or `CallPrimitive`.
|
||||
If op is custom defined, it is treated by TreeNode.
|
||||
A `CallCell` node represents an invoking to cell-op.
|
||||
A `CallPrimitive` node represents an invoking to primitive-op.
|
||||
|
||||
Args:
|
||||
op (Union[Cell, Primitive]): An instance of `Cell` or `Primitive` corresponding to this node.
|
||||
ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast.
|
||||
targets (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
|
||||
func ([ScopedValue, optional]): An instance of `ScopedValue`. See detail in docstring of Node class.
|
||||
args (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
|
||||
kwargs (dict{str: ScopedValue}): A list of instance of `ScopedValue`. See detail in docstring of `Node`
|
||||
class.
|
||||
name (str): A string represents name of node. Name of node will be unique when inserted into `SymbolTree`.
|
||||
Name of node also used as field name in network class.
|
||||
"""
|
||||
cls_name = type(op).__name__
|
||||
|
||||
if is_subtree(cls_name):
|
||||
from .symbol_tree_builder import SymbolTreeBuilder
|
||||
stb = SymbolTreeBuilder(op)
|
||||
stree = stb.build()
|
||||
replacer = AstReplacer(stree.get_class_ast())
|
||||
replacer.replace_all(stree.get_ori_cls_name(), stree.get_opt_cls_name())
|
||||
return TreeNode.create_tree_node(stree, None, targets, ScopedValue.create_naming_value(name, "self"),
|
||||
args, kwargs, name, op)
|
||||
|
||||
return Node.create_call_buildin_op(op, None, targets, ScopedValue.create_naming_value(name, "self"),
|
||||
args, kwargs, name)
|
||||
|
||||
@staticmethod
|
||||
def _get_construct_arg_names(parameters):
|
||||
"""
|
||||
|
@ -440,6 +477,10 @@ class Node:
|
|||
"""
|
||||
return self._prev
|
||||
|
||||
def set_prev(self, prev):
|
||||
"""Set previous node of current node in source code order. """
|
||||
self._prev = prev
|
||||
|
||||
def get_next(self) -> 'Node':
|
||||
"""
|
||||
Get next node of current node in source code order.
|
||||
|
@ -449,6 +490,10 @@ class Node:
|
|||
"""
|
||||
return self._next
|
||||
|
||||
def set_next(self, _next):
|
||||
"""Set next node of current node in source code order."""
|
||||
self._next = _next
|
||||
|
||||
def has_same_ast(self, node: Union['Node', ast.AST]) -> bool:
|
||||
"""
|
||||
Check if other node holds same ast node with self.
|
||||
|
@ -460,7 +505,7 @@ class Node:
|
|||
A bool.
|
||||
"""
|
||||
if isinstance(node, Node):
|
||||
return self.has_same_ast(node._ast_node)
|
||||
return self.has_same_ast(node.get_ast())
|
||||
if isinstance(node, ast.AST):
|
||||
return id(self._ast_node) == id(node)
|
||||
return False
|
||||
|
@ -570,7 +615,8 @@ class Node:
|
|||
keyword_map_index[keyword_ast.arg] = index
|
||||
for keyword_index in range(self._kwargs_num):
|
||||
key = self._normalized_args_keys[keyword_index + self._args_num]
|
||||
AstModifier.update_arg_value(self._normalized_args.get(key), keywords_ast[keyword_map_index[key]].value)
|
||||
AstModifier.update_arg_value(self._normalized_args.get(key),
|
||||
keywords_ast[keyword_map_index.get(key)].value)
|
||||
|
||||
def _sync_call_method_args_to_ast(self):
|
||||
"""Sync args of ast.Cell of ast.Assign from self._normalized_args when NodeType is CallMethod."""
|
||||
|
@ -646,9 +692,9 @@ class Node:
|
|||
origin_prev: Optional[Node] = self._prev
|
||||
origin_next: Optional[Node] = self._next
|
||||
if origin_prev is not None:
|
||||
origin_prev._next = origin_next
|
||||
origin_prev.set_next(origin_next)
|
||||
if origin_next is not None:
|
||||
origin_next._prev = origin_prev
|
||||
origin_next.set_prev(origin_prev)
|
||||
self._prev = None
|
||||
self._next = None
|
||||
|
||||
|
@ -662,9 +708,9 @@ class Node:
|
|||
node.isolate()
|
||||
origin_prev: Optional[Node] = self._prev
|
||||
if origin_prev is not None:
|
||||
origin_prev._next = node
|
||||
node._prev = origin_prev
|
||||
node._next = self
|
||||
origin_prev.set_next(node)
|
||||
node.set_prev(origin_prev)
|
||||
node.set_next(self)
|
||||
self._prev = node
|
||||
|
||||
def insert_after(self, node: 'Node'):
|
||||
|
@ -677,10 +723,10 @@ class Node:
|
|||
node.isolate()
|
||||
origin_next: Optional[Node] = self._next
|
||||
self._next = node
|
||||
node._prev = self
|
||||
node._next = origin_next
|
||||
node.set_prev(self)
|
||||
node.set_next(origin_next)
|
||||
if origin_next is not None:
|
||||
origin_next._prev = node
|
||||
origin_next.set_prev(node)
|
||||
|
||||
def get_inputs(self) -> ['Node']:
|
||||
"""
|
||||
|
@ -850,12 +896,12 @@ class Node:
|
|||
if arg_idx >= self._args_num or arg_idx < 0:
|
||||
raise RuntimeError("arg_idx out of range: ", arg_idx)
|
||||
if out_idx is None:
|
||||
if len(node._targets) != 1:
|
||||
if len(node.get_targets()) != 1:
|
||||
raise RuntimeError("node should has one output when out_idx is not provided")
|
||||
out_idx = 0
|
||||
if out_idx >= len(node._targets):
|
||||
if out_idx >= len(node.get_targets()):
|
||||
raise RuntimeError("out_idx out of range: ", out_idx)
|
||||
new_arg = node._targets[out_idx]
|
||||
new_arg = node.get_targets()[out_idx]
|
||||
self._normalized_args[self._normalized_args_keys[arg_idx]] = new_arg
|
||||
self._sync_arg()
|
||||
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Visit nods of SymbolTree."""
|
||||
|
||||
|
||||
class NodeVisitor:
|
||||
"""Iterator class to access SymbolTree nodes"""
|
||||
def __init__(self, stree):
|
||||
self._stree = stree
|
||||
self._nodes = []
|
||||
self._index = 0
|
||||
|
||||
def __iter__(self):
|
||||
self._nodes = list(self._stree.get_nodes_dict().values())
|
||||
self._index = 0
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self._index < len(self._nodes):
|
||||
node = self._nodes[self._index]
|
||||
self._index += 1
|
||||
return node
|
||||
|
||||
raise StopIteration
|
||||
|
||||
def append_node(self, node):
|
||||
"""append new node to iterator"""
|
||||
self._nodes.append(node)
|
||||
|
||||
def remove_node(self, node):
|
||||
"""remove node of iterator"""
|
||||
self._nodes.remove(node)
|
|
@ -186,7 +186,6 @@ class AssignParser(Parser):
|
|||
return results
|
||||
|
||||
def _is_subtree_cell(self, cell: Cell) -> bool:
|
||||
assert isinstance(cell, Cell)
|
||||
return not type(cell).__name__ in self._cell_namespce
|
||||
|
||||
@staticmethod
|
||||
|
@ -353,7 +352,6 @@ class AssignParser(Parser):
|
|||
# self._subnet = SubNet1(global_vars.get("subnet_args"))
|
||||
# so a change in sub-network should also be identified as a change in main-network.
|
||||
# so main-network should observe sub-network
|
||||
new_stree.reg_observer(stree)
|
||||
replacer = AstReplacer(new_stree.get_class_ast())
|
||||
replacer.replace_all(new_stree.get_ori_cls_name(), new_stree.get_opt_cls_name())
|
||||
return TreeNode(new_stree, father_ast_node, targets, func, call_args, call_kwargs, func_name,
|
||||
|
|
|
@ -114,7 +114,6 @@ class ClassDefParser(Parser):
|
|||
|
||||
def _is_subtree_field(self, ori_net, field) -> bool:
|
||||
op = getattr(ori_net, field)
|
||||
assert op is not None
|
||||
return not type(op).__name__ in self._cell_namespace
|
||||
|
||||
def _process_init_func_ast(self, stree: SymbolTree, init_ast: ast.FunctionDef):
|
||||
|
|
|
@ -32,6 +32,7 @@ from .namer import TargetNamer, NodeNamer, ClassNamer
|
|||
from .common.observer import Observer
|
||||
from .common.observable import Observable
|
||||
from .common.event import Event
|
||||
from .node_visitor import NodeVisitor
|
||||
|
||||
|
||||
class Position:
|
||||
|
@ -113,6 +114,7 @@ class SymbolTree(Observer, Observable):
|
|||
self._return: Optional[Node] = None
|
||||
|
||||
self._modified = False
|
||||
self._node_visitor = None
|
||||
|
||||
def finish_build(self):
|
||||
self.add_event(Event.TopologicalChangeEvent)
|
||||
|
@ -258,25 +260,27 @@ class SymbolTree(Observer, Observable):
|
|||
raise RuntimeError("Key of global_vars duplicated:", key)
|
||||
self._global_vars[key] = value
|
||||
|
||||
def nodes(self, unfold_subtree=False):
|
||||
def get_nodes_dict(self):
|
||||
"""Get dict of nodes"""
|
||||
return self._nodes
|
||||
|
||||
def nodes(self):
|
||||
"""
|
||||
Getter of nodes if current SymbolTree.
|
||||
|
||||
Args:
|
||||
unfold_subtree (bool): Need to iterate into sub-symbol-tree recursively.
|
||||
|
||||
Returns:
|
||||
A list of instance of Nodes.
|
||||
"""
|
||||
if unfold_subtree:
|
||||
nodes = []
|
||||
for _, v in self._nodes.items():
|
||||
if isinstance(v, TreeNode):
|
||||
nodes.extend(v.symbol_tree.nodes(unfold_subtree))
|
||||
else:
|
||||
nodes.append(v)
|
||||
return nodes
|
||||
return self._nodes.values()
|
||||
if self._node_visitor is None:
|
||||
self._node_visitor = NodeVisitor(self)
|
||||
it = iter(self._node_visitor)
|
||||
|
||||
while True:
|
||||
try:
|
||||
n = next(it)
|
||||
yield n
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def get_node(self, node_name: str) -> Optional[Node]:
|
||||
"""
|
||||
|
@ -431,19 +435,44 @@ class SymbolTree(Observer, Observable):
|
|||
# _unique_targets must called after _update_args_for_unique and _update_kwargs_for_unique
|
||||
self._unique_targets(node)
|
||||
self._insert_node(position, node)
|
||||
if isinstance(node, TreeNode):
|
||||
node.symbol_tree.reg_observer(self)
|
||||
if self._node_visitor:
|
||||
self._node_visitor.append_node(node)
|
||||
# update init-function-ast and construct-function-ast
|
||||
if insert_to_ast:
|
||||
node.set_func(ScopedValue.create_naming_value(node_name, "self"))
|
||||
node_ast = node.get_ast()
|
||||
if not isinstance(node_ast, ast.Assign):
|
||||
raise RuntimeError("Only support insert cell op now")
|
||||
AstModifier.insert_assign_to_function(self._init_func_ast,
|
||||
targets=[ScopedValue(ValueType.NamingValue, "self", node_name)],
|
||||
expr=ScopedValue(ValueType.NamingValue, "global_vars", "get"),
|
||||
args=[ScopedValue(ValueType.StringValue, "", node_name)])
|
||||
AstModifier.insert_assign_ast_to_function(self._root_ast, node_ast,
|
||||
None if position is None else position.node.get_ast(),
|
||||
position.before_node)
|
||||
if isinstance(node, TreeNode):
|
||||
global_vars_key = node.get_name() + "_args"
|
||||
self.add_global_vars(global_vars_key, node.symbol_tree.get_global_vars())
|
||||
args_call = AstModifier.create_call(ScopedValue.create_naming_value("get", "global_vars"),
|
||||
[ScopedValue.create_variable_value(global_vars_key)])
|
||||
value = ast.Call(func=ast.Name(node.symbol_tree.get_opt_cls_name(), ast.Store(), lineno=0,
|
||||
col_offset=0), args=[args_call], keywords=[], lineno=0, col_offset=0)
|
||||
|
||||
ast_target = ast.Name("self." + node.get_name(), ast.Store(), lineno=0, col_offset=0)
|
||||
assign = ast.Assign(targets=[ast_target], value=value, lineno=0, col_offset=0)
|
||||
AstModifier.insert_assign_ast_to_function(self._init_func_ast, assign)
|
||||
|
||||
assign_construct = AstModifier.create_call_assign(node.get_targets(), ScopedValue.create_naming_value
|
||||
(node.get_name(), "self"), node.get_args(), {})
|
||||
AstModifier.insert_assign_ast_to_function(self._root_ast, assign_construct,
|
||||
None if position is None else position.node.get_ast(),
|
||||
position.before_node)
|
||||
sub_stree: SymbolTree = node.symbol_tree
|
||||
from .symbol_tree_builder import SymbolTreeBuilder
|
||||
SymbolTreeBuilder.merge_module_of_subtree(self, sub_stree)
|
||||
else:
|
||||
node_ast = node.get_ast()
|
||||
if not isinstance(node_ast, ast.Assign):
|
||||
raise RuntimeError("Only support insert cell op now")
|
||||
AstModifier.insert_assign_to_function(self._init_func_ast,
|
||||
targets=[ScopedValue(ValueType.NamingValue, "self", node_name)],
|
||||
expr=ScopedValue(ValueType.NamingValue, "global_vars", "get"),
|
||||
args=[ScopedValue(ValueType.StringValue, "", node_name)])
|
||||
AstModifier.insert_assign_ast_to_function(self._root_ast, node_ast,
|
||||
None if position is None else position.node.get_ast(),
|
||||
position.before_node)
|
||||
self._global_vars[node_name] = node.get_instance()
|
||||
return node
|
||||
|
||||
|
@ -595,6 +624,8 @@ class SymbolTree(Observer, Observable):
|
|||
value.isolate()
|
||||
break
|
||||
self._topo_mgr.on_erase_node(node)
|
||||
if self._node_visitor:
|
||||
self._node_visitor.remove_node(node)
|
||||
return node
|
||||
|
||||
def _insert_tree(self, position: Position, root: Node, insert_to_ast: bool = True) -> Node:
|
||||
|
|
|
@ -44,6 +44,22 @@ class SymbolTreeBuilder:
|
|||
self._ast_root: ast.Module = ast.parse(network_str)
|
||||
self._root_tree: Optional[SymbolTree] = None
|
||||
|
||||
@staticmethod
|
||||
def merge_module_of_subtree(main_tree: SymbolTree, sub_stree: SymbolTree):
|
||||
"""
|
||||
Merge ast.Module of sub-network into ast.Module of main-network.
|
||||
|
||||
1. Merge imports of ast.Module.
|
||||
2. Merge classes of ast.Module.
|
||||
3. Use merged ast.Module as module of main-network and sub-network.
|
||||
"""
|
||||
|
||||
father_mod = main_tree.get_module_ast()
|
||||
sub_mod = sub_stree.get_module_ast()
|
||||
SymbolTreeBuilder._merge_import_of_module(father_mod, sub_mod)
|
||||
SymbolTreeBuilder._merge_class_of_module(father_mod, sub_mod)
|
||||
sub_stree.set_module_ast(father_mod)
|
||||
|
||||
@staticmethod
|
||||
def _ast_transform(ast_root: ast.AST) -> ast.AST:
|
||||
"""
|
||||
|
@ -80,7 +96,6 @@ class SymbolTreeBuilder:
|
|||
main_mod_finder = AstFinder(main_mod)
|
||||
imports_in_sub = copy(sub_mod_finder.find_all((ast.Import, ast.ImportFrom)))
|
||||
imports_in_main = copy(main_mod_finder.find_all((ast.Import, ast.ImportFrom)))
|
||||
assert imports_in_main
|
||||
first_import = imports_in_main[0]
|
||||
for clazz in imports_in_sub:
|
||||
AstModifier.insert_sub_ast(main_mod, clazz, first_import, True)
|
||||
|
@ -103,12 +118,11 @@ class SymbolTreeBuilder:
|
|||
main_mod_finder = AstFinder(main_mod)
|
||||
classes_in_sub = copy(sub_mod_finder.find_all(ast.ClassDef))
|
||||
classes_in_main = copy(main_mod_finder.find_all(ast.ClassDef))
|
||||
assert classes_in_main
|
||||
first_class = classes_in_main[0]
|
||||
for clazz in classes_in_sub:
|
||||
AstModifier.insert_class_into_module(main_mod, clazz, first_class, True)
|
||||
|
||||
def _merge_module_of_subtree(self):
|
||||
def _merge_module_of_subtrees(self):
|
||||
"""
|
||||
Merge ast.Module of all sub-networks into ast.Module of main-network.
|
||||
|
||||
|
@ -117,13 +131,9 @@ class SymbolTreeBuilder:
|
|||
3. Use merged ast.Module as module of main-network and sub-network.
|
||||
"""
|
||||
|
||||
father_mod = self._root_tree.get_module_ast()
|
||||
for node in self._root_tree.nodes():
|
||||
if isinstance(node, TreeNode):
|
||||
sub_stree: SymbolTree = node.symbol_tree
|
||||
SymbolTreeBuilder._merge_import_of_module(father_mod, sub_stree.get_module_ast())
|
||||
SymbolTreeBuilder._merge_class_of_module(father_mod, sub_stree.get_module_ast())
|
||||
sub_stree.set_module_ast(father_mod)
|
||||
SymbolTreeBuilder.merge_module_of_subtree(self._root_tree, node.symbol_tree)
|
||||
|
||||
def _reduce_redundant_import(self):
|
||||
"""
|
||||
|
@ -140,7 +150,6 @@ class SymbolTreeBuilder:
|
|||
if isinstance(body, ast.Import):
|
||||
names = body.names
|
||||
for name in names:
|
||||
assert isinstance(name, ast.alias)
|
||||
import_hash = hash((name.name, name.asname))
|
||||
if import_hash in exist_import:
|
||||
continue
|
||||
|
@ -150,7 +159,6 @@ class SymbolTreeBuilder:
|
|||
import_module = body.module
|
||||
names = body.names
|
||||
for name in names:
|
||||
assert isinstance(name, ast.alias)
|
||||
import_hash = hash((import_module, name.name, name.asname))
|
||||
if import_hash in exist_import_from:
|
||||
continue
|
||||
|
@ -182,7 +190,7 @@ class SymbolTreeBuilder:
|
|||
self._root_tree: SymbolTree = SymbolTree(self._origin_net, self._ast_root)
|
||||
parser: Parser = ParserRegister.instance().get_parser(ast.Module)
|
||||
parser.process(self._root_tree, self._ast_root)
|
||||
self._merge_module_of_subtree()
|
||||
self._merge_module_of_subtrees()
|
||||
self._reduce_redundant_import()
|
||||
ast.fix_missing_locations(self._root_tree.get_module_ast())
|
||||
self._root_tree.finish_build()
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
from mindspore.nn import Cell, Conv2d
|
||||
from mindspore.rewrite import SymbolTree
|
||||
from mindspore.ops import operations as P
|
||||
from .utils import get_node_by_index
|
||||
|
||||
|
||||
class SubNet(Cell):
|
||||
|
@ -56,7 +57,8 @@ def test_multi_targets():
|
|||
"""
|
||||
test_cls = NetMultiTargets()
|
||||
stree = SymbolTree.create(test_cls)
|
||||
node = stree.nodes()[2]
|
||||
node = get_node_by_index(stree, 2)
|
||||
assert node is not None
|
||||
targets = node.get_targets()
|
||||
assert targets[0].value == 'c1'
|
||||
assert targets[1].value == 'c2'
|
||||
|
|
|
@ -121,7 +121,6 @@ def erase_node_x_11(stree: SymbolTree):
|
|||
|
||||
def transform(stree: SymbolTree):
|
||||
add_conv_before_flatten(stree)
|
||||
add_my_cell_after_x_12(stree)
|
||||
erase_node_x_11(stree)
|
||||
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ from mindspore.nn import Cell, Conv2d, BatchNorm2d, ReLU
|
|||
from mindspore.ops import Add, AddN
|
||||
from mindspore.rewrite import ScopedValue, Node, SymbolTree
|
||||
from mindspore.rewrite import PatternEngine, PatternNode, Replacement, VarNode
|
||||
from .utils import get_symbol_tree_nodes_count
|
||||
|
||||
|
||||
def test_tree_pattern_match():
|
||||
|
@ -92,13 +93,13 @@ def test_one_to_one_pattern():
|
|||
assert bn is not None
|
||||
assert relu1 is not None
|
||||
assert len(construct_ast.body) == 6
|
||||
assert len(stree.nodes()) == 7
|
||||
assert get_symbol_tree_nodes_count(stree) == 7
|
||||
|
||||
bn_replace = BnReplace()
|
||||
bn_replace.apply(stree)
|
||||
|
||||
assert len(construct_ast.body) == 6
|
||||
assert len(stree.nodes()) == 7
|
||||
assert get_symbol_tree_nodes_count(stree) == 7
|
||||
conv = stree.get_node("conv")
|
||||
bn = stree.get_node("bn")
|
||||
relu1 = stree.get_node("relu1")
|
||||
|
@ -167,13 +168,13 @@ def test_one_to_multi_chain_pattern():
|
|||
assert bn is not None
|
||||
assert relu1 is not None
|
||||
assert len(construct_ast.body) == 6
|
||||
assert len(stree.nodes()) == 7
|
||||
assert get_symbol_tree_nodes_count(stree) == 7
|
||||
|
||||
bn_replace = BnReplace()
|
||||
bn_replace.apply(stree)
|
||||
|
||||
assert len(construct_ast.body) == 7
|
||||
assert len(stree.nodes()) == 8
|
||||
assert get_symbol_tree_nodes_count(stree) == 8
|
||||
conv = stree.get_node("conv")
|
||||
bn = stree.get_node("bn")
|
||||
relu1 = stree.get_node("relu1")
|
||||
|
@ -296,13 +297,13 @@ def test_tree_pattern():
|
|||
assert relu2 is not None
|
||||
construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast")
|
||||
assert len(construct_ast.body) == 8
|
||||
assert len(stree.nodes()) == 9
|
||||
assert get_symbol_tree_nodes_count(stree) == 9
|
||||
|
||||
add_relu_pattern = AddReluPattern()
|
||||
add_relu_pattern.apply(stree)
|
||||
|
||||
assert len(construct_ast.body) == 10
|
||||
assert len(stree.nodes()) == 11
|
||||
assert get_symbol_tree_nodes_count(stree) == 11
|
||||
conv1 = stree.get_node("conv1")
|
||||
conv2 = stree.get_node("conv2")
|
||||
add = stree.get_node("add")
|
||||
|
@ -481,13 +482,13 @@ def test_multi_input_to_multi_pattern_tree_pattern():
|
|||
assert relu is not None
|
||||
construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast")
|
||||
assert len(construct_ast.body) == 6
|
||||
assert len(stree.nodes()) == 9
|
||||
assert get_symbol_tree_nodes_count(stree) == 9
|
||||
|
||||
multi_input_pattern = MultiInputPattern()
|
||||
multi_input_pattern.apply(stree)
|
||||
|
||||
assert len(construct_ast.body) == 4
|
||||
assert len(stree.nodes()) == 7
|
||||
assert get_symbol_tree_nodes_count(stree) == 7
|
||||
conv1 = stree.get_node("conv1")
|
||||
conv2 = stree.get_node("conv2")
|
||||
add1 = stree.get_node("add1")
|
||||
|
@ -598,13 +599,13 @@ def test_one_input_to_multi_pattern_tree_pattern():
|
|||
assert relu is not None
|
||||
construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast")
|
||||
assert len(construct_ast.body) == 6
|
||||
assert len(stree.nodes()) == 7
|
||||
assert get_symbol_tree_nodes_count(stree) == 7
|
||||
|
||||
multi_input_pattern = MultiInputPattern()
|
||||
multi_input_pattern.apply(stree)
|
||||
|
||||
assert len(construct_ast.body) == 4
|
||||
assert len(stree.nodes()) == 5
|
||||
assert get_symbol_tree_nodes_count(stree) == 5
|
||||
conv1 = stree.get_node("conv1")
|
||||
conv2 = stree.get_node("conv2")
|
||||
add1 = stree.get_node("add1")
|
||||
|
|
|
@ -100,10 +100,52 @@ def erase_relu_in_conv2(stree: SymbolTree):
|
|||
break
|
||||
|
||||
|
||||
def inset_subtree(stree: SymbolTree):
|
||||
for node in stree.nodes():
|
||||
if node.get_name() == "conv2":
|
||||
position = stree.before(node)
|
||||
subtree = SubNet()
|
||||
new_node = Node.create_call_cell(subtree, targets=[ScopedValue.create_naming_value('x')], name='conv',
|
||||
args=[ScopedValue.create_naming_value('x')], kwargs={})
|
||||
stree.insert(position, new_node)
|
||||
break
|
||||
|
||||
|
||||
def inset_subtree2(stree: SymbolTree):
|
||||
for node in stree.nodes():
|
||||
if node.get_name() == "conv2":
|
||||
position = stree.before(node)
|
||||
subtree = SubNet()
|
||||
new_node = Node.create_call_cell(subtree, targets=[ScopedValue.create_naming_value('x')], name='conv11',
|
||||
args=[ScopedValue.create_naming_value('x')], kwargs={})
|
||||
stree.insert(position, new_node)
|
||||
break
|
||||
|
||||
|
||||
def add_relu_in_conv11(stree: SymbolTree):
|
||||
for node in stree.nodes():
|
||||
if node.get_node_type() != NodeType.Tree:
|
||||
continue
|
||||
if node.get_name() == "conv11":
|
||||
_stree: SymbolTree = TreeNodeHelper.get_sub_tree(node)
|
||||
for inner_node in _stree.nodes():
|
||||
if inner_node.get_node_type() != NodeType.Output:
|
||||
continue
|
||||
position = _stree.before(inner_node)
|
||||
new_relu = nn.ReLU()
|
||||
new_relu_node = Node.create_call_cell(new_relu, targets=['x'], name='relu1',
|
||||
args=[ScopedValue.create_naming_value('x')])
|
||||
_stree.insert(position, new_relu_node)
|
||||
_stree.set_output(0, new_relu_node.get_targets()[0].value)
|
||||
break
|
||||
break
|
||||
|
||||
|
||||
def transform(stree: SymbolTree):
|
||||
add_relu_in_conv1(stree)
|
||||
replace_bn_in_conv2(stree)
|
||||
erase_relu_in_conv2(stree)
|
||||
inset_subtree(stree)
|
||||
|
||||
|
||||
def test_subtree_net():
|
||||
|
@ -115,7 +157,20 @@ def test_subtree_net():
|
|||
|
||||
net = MainNet()
|
||||
stree = SymbolTree.create(net)
|
||||
print(stree.get_code())
|
||||
transform(stree)
|
||||
for node in stree.nodes():
|
||||
print("after transform node name: ", node.get_name(), "; node type: ", node.get_node_type())
|
||||
if node.get_node_type() != NodeType.Tree:
|
||||
continue
|
||||
if node.get_name() == "conv":
|
||||
modify_stree: SymbolTree = TreeNodeHelper.get_sub_tree(node)
|
||||
for inner_node in modify_stree.nodes():
|
||||
print("inserted subtree node: ", inner_node.get_name())
|
||||
|
||||
inset_subtree2(stree)
|
||||
add_relu_in_conv11(stree)
|
||||
|
||||
print(stree.get_code())
|
||||
print(stree.get_handler().get_global_vars().keys())
|
||||
net_opt = stree.get_network()
|
||||
|
|
|
@ -21,6 +21,7 @@ from mindspore.rewrite import ScopedValue, ValueType, NodeType
|
|||
from mindspore.rewrite import Node as NodeApi
|
||||
from mindspore.rewrite.symbol_tree import SymbolTree
|
||||
from mindspore.rewrite.node import Node
|
||||
from .utils import get_symbol_tree_nodes_count
|
||||
|
||||
|
||||
class Network(Cell):
|
||||
|
@ -107,7 +108,7 @@ def test_insert_node():
|
|||
consumers = getattr(getattr(stree, "_topo_mgr"), "_target_consumer")
|
||||
providers_len = len(providers)
|
||||
consumers_len = len(consumers)
|
||||
assert len(stree.nodes()) == 7
|
||||
assert get_symbol_tree_nodes_count(stree) == 7
|
||||
assert len(construct_ast.body) == 6
|
||||
assert len(relu1.get_targets()) == 1
|
||||
assert len(relu2.get_normalized_args().values()) == 1
|
||||
|
@ -120,7 +121,7 @@ def test_insert_node():
|
|||
position = stree.before(relu2)
|
||||
node = stree.insert_node(position, node)
|
||||
# check nodes size
|
||||
assert len(stree.nodes()) == 8
|
||||
assert get_symbol_tree_nodes_count(stree) == 8
|
||||
# check args
|
||||
assert len(relu2.get_normalized_args().values()) == 1
|
||||
assert relu1.get_targets()[0] == list(relu2.get_normalized_args().values())[0]
|
||||
|
@ -158,7 +159,7 @@ def test_set_node_arg():
|
|||
Expectation: Success.
|
||||
"""
|
||||
stree, bn, relu1, relu2 = create_symbol_tree()
|
||||
assert len(stree.nodes()) == 7
|
||||
assert get_symbol_tree_nodes_count(stree) == 7
|
||||
assert len(bn.get_targets()) == 1
|
||||
bn_output = bn.get_targets()[0]
|
||||
# check bn topological order
|
||||
|
@ -210,7 +211,7 @@ def test_set_node_arg_by_node():
|
|||
Expectation: Success.
|
||||
"""
|
||||
stree, bn, relu1, relu2 = create_symbol_tree()
|
||||
assert len(stree.nodes()) == 7
|
||||
assert get_symbol_tree_nodes_count(stree) == 7
|
||||
assert len(bn.get_targets()) == 1
|
||||
bn_output = bn.get_targets()[0]
|
||||
# check bn topological order
|
||||
|
@ -265,13 +266,13 @@ def test_erase_succeed():
|
|||
construct_ast: ast.FunctionDef = getattr(stree, "_root_ast")
|
||||
providers = getattr(getattr(stree, "_topo_mgr"), "_target_provider")
|
||||
providers_len = len(providers)
|
||||
assert len(stree.nodes()) == 7
|
||||
assert get_symbol_tree_nodes_count(stree) == 7
|
||||
assert len(construct_ast.body) == 6
|
||||
|
||||
stree.set_node_arg_by_node(relu2, 0, bn)
|
||||
stree.erase_node(relu1)
|
||||
|
||||
assert len(stree.nodes()) == 6
|
||||
assert get_symbol_tree_nodes_count(stree) == 6
|
||||
assert len(providers) == providers_len - 1
|
||||
assert len(construct_ast.body) == 5
|
||||
|
||||
|
@ -300,13 +301,13 @@ def test_replace_one_to_one():
|
|||
stree, bn, relu1, relu2 = create_symbol_tree()
|
||||
construct_ast: ast.FunctionDef = getattr(stree, "_root_ast")
|
||||
assert len(construct_ast.body) == 6
|
||||
assert len(stree.nodes()) == 7
|
||||
assert get_symbol_tree_nodes_count(stree) == 7
|
||||
|
||||
new_conv = Conv2d(16, 16, 5)
|
||||
new_conv_node = NodeApi.create_call_cell(new_conv, [ScopedValue.create_naming_value("new_conv")],
|
||||
bn.get_targets()).get_handler()
|
||||
new_conv_node = stree.replace(relu1, [new_conv_node])
|
||||
assert len(stree.nodes()) == 7
|
||||
assert get_symbol_tree_nodes_count(stree) == 7
|
||||
# check ast
|
||||
assert len(construct_ast.body) == 6
|
||||
node_ast: ast.Assign = construct_ast.body[2]
|
||||
|
@ -341,7 +342,7 @@ def test_replace_one_to_multi():
|
|||
stree, bn, relu1, relu2 = create_symbol_tree()
|
||||
construct_ast: ast.FunctionDef = getattr(stree, "_root_ast")
|
||||
assert len(construct_ast.body) == 6
|
||||
assert len(stree.nodes()) == 7
|
||||
assert get_symbol_tree_nodes_count(stree) == 7
|
||||
|
||||
new_conv_node = NodeApi.create_call_cell(Conv2d(16, 16, 5), [ScopedValue.create_naming_value("new_conv")],
|
||||
bn.get_targets()).get_handler()
|
||||
|
@ -350,7 +351,7 @@ def test_replace_one_to_multi():
|
|||
new_relu_node = stree.replace(relu1, [new_relu_node, new_conv_node])
|
||||
new_conv_node = new_relu_node.get_inputs()[0]
|
||||
|
||||
assert len(stree.nodes()) == 8
|
||||
assert get_symbol_tree_nodes_count(stree) == 8
|
||||
# check ast
|
||||
assert len(construct_ast.body) == 7
|
||||
new_conv_ast: ast.Assign = construct_ast.body[2]
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
from mindspore.rewrite.api.symbol_tree import SymbolTree
|
||||
|
||||
|
||||
def get_symbol_tree_nodes_count(stree: SymbolTree):
|
||||
count = 0
|
||||
for _ in stree.nodes():
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
def get_node_by_index(stree: SymbolTree, index):
|
||||
for i, node in enumerate(stree.nodes()):
|
||||
if i == index:
|
||||
return node
|
||||
|
||||
return None
|
Loading…
Reference in New Issue