diff --git a/mindspore/python/mindspore/rewrite/api/node.py b/mindspore/python/mindspore/rewrite/api/node.py index 8732a2c73ff..6936f38e44e 100644 --- a/mindspore/python/mindspore/rewrite/api/node.py +++ b/mindspore/python/mindspore/rewrite/api/node.py @@ -88,8 +88,8 @@ class Node: RuntimeError: If value of kwarg in `kwargs` is not a `NamingValue`-`ScopedValue` or a `CustomObjValue`-`ScopedValue`. """ - return Node(NodeImpl.create_call_buildin_op(cell, None, targets, ScopedValue.create_naming_value(name, "self"), - args, kwargs, name)) + return Node(NodeImpl.create_call_op(cell, None, targets, ScopedValue.create_naming_value(name, "self"), + args, kwargs, name)) def get_prev(self) -> 'Node': """ diff --git a/mindspore/python/mindspore/rewrite/api/symbol_tree.py b/mindspore/python/mindspore/rewrite/api/symbol_tree.py index 3e998f3b646..b428d446a63 100644 --- a/mindspore/python/mindspore/rewrite/api/symbol_tree.py +++ b/mindspore/python/mindspore/rewrite/api/symbol_tree.py @@ -51,14 +51,15 @@ class SymbolTree: """ return self._symbol_tree - def nodes(self) -> {}: + def nodes(self): """ Get all nodes of corresponding network. Returns: A dict mapping from name of node to node. """ - return [Node(node_impl) for node_impl in self._symbol_tree.nodes(unfold_subtree=False)] + for node in self._symbol_tree.nodes(): + yield Node(node) def get_node(self, node_name: str) -> Optional[Node]: """ diff --git a/mindspore/python/mindspore/rewrite/api/tree_node_helper.py b/mindspore/python/mindspore/rewrite/api/tree_node_helper.py index 10a66757206..b3a1d38c0f3 100644 --- a/mindspore/python/mindspore/rewrite/api/tree_node_helper.py +++ b/mindspore/python/mindspore/rewrite/api/tree_node_helper.py @@ -19,7 +19,6 @@ from .symbol_tree import SymbolTree from .node import Node from .node_type import NodeType from ..symbol_tree import SymbolTree as SymbolTreeImpl -from ..node import TreeNode class TreeNodeHelper: @@ -47,7 +46,6 @@ class TreeNodeHelper: if node.get_node_type() == NodeType.Tree: node_impl = node.get_handler() - assert isinstance(node_impl, TreeNode) subtree: SymbolTreeImpl = node_impl.symbol_tree if subtree is None: return None diff --git a/mindspore/python/mindspore/rewrite/namespace.py b/mindspore/python/mindspore/rewrite/namespace.py new file mode 100644 index 00000000000..7baa5d2a224 --- /dev/null +++ b/mindspore/python/mindspore/rewrite/namespace.py @@ -0,0 +1,31 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Define the namespace of MindSpore op definition.""" +from .._extends.parse.namespace import CellNamespace + + +_ms_common_ns = CellNamespace('mindspore.common') +_ms_nn_ns = CellNamespace('mindspore.nn') +_ms_ops_ns = CellNamespace('mindspore.ops') + + +def is_subtree(cls_name): + """Determine whether 'cls_name' is a subtree.""" + if cls_name == "SequentialCell": + return True + if cls_name in _ms_common_ns or cls_name in _ms_nn_ns or cls_name in _ms_ops_ns: + return False + + return True diff --git a/mindspore/python/mindspore/rewrite/node.py b/mindspore/python/mindspore/rewrite/node.py index cb38f4b88d7..777933ff703 100644 --- a/mindspore/python/mindspore/rewrite/node.py +++ b/mindspore/python/mindspore/rewrite/node.py @@ -23,6 +23,8 @@ from mindspore import log as logger from .ast_helpers import AstModifier from .api.scoped_value import ScopedValue, ValueType from .api.node_type import NodeType +from .namespace import is_subtree +from .ast_helpers.ast_replacer import AstReplacer PASS_THROUGH_METHOD = ScopedValue.create_naming_value("PassThrough") @@ -223,6 +225,41 @@ class Node: return cls(NodeType.Output, ast_node, None, ScopedValue.create_naming_value("return"), real_return_values, {}, name, None) + @staticmethod + def create_call_op(op: Union[Cell, Primitive], ast_node: Optional[ast.AST], + targets: [Union[ScopedValue, str]], func: Union[ScopedValue, str], + args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None, name: str = ""): + """ + Static method of Node. Instantiate an instance of node whose type is `CallCell` or `CallPrimitive`. + If op is custom defined, it is treated by TreeNode. + A `CallCell` node represents an invoking to cell-op. + A `CallPrimitive` node represents an invoking to primitive-op. + + Args: + op (Union[Cell, Primitive]): An instance of `Cell` or `Primitive` corresponding to this node. + ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast. + targets (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class. + func ([ScopedValue, optional]): An instance of `ScopedValue`. See detail in docstring of Node class. + args (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class. + kwargs (dict{str: ScopedValue}): A list of instance of `ScopedValue`. See detail in docstring of `Node` + class. + name (str): A string represents name of node. Name of node will be unique when inserted into `SymbolTree`. + Name of node also used as field name in network class. + """ + cls_name = type(op).__name__ + + if is_subtree(cls_name): + from .symbol_tree_builder import SymbolTreeBuilder + stb = SymbolTreeBuilder(op) + stree = stb.build() + replacer = AstReplacer(stree.get_class_ast()) + replacer.replace_all(stree.get_ori_cls_name(), stree.get_opt_cls_name()) + return TreeNode.create_tree_node(stree, None, targets, ScopedValue.create_naming_value(name, "self"), + args, kwargs, name, op) + + return Node.create_call_buildin_op(op, None, targets, ScopedValue.create_naming_value(name, "self"), + args, kwargs, name) + @staticmethod def _get_construct_arg_names(parameters): """ @@ -440,6 +477,10 @@ class Node: """ return self._prev + def set_prev(self, prev): + """Set previous node of current node in source code order. """ + self._prev = prev + def get_next(self) -> 'Node': """ Get next node of current node in source code order. @@ -449,6 +490,10 @@ class Node: """ return self._next + def set_next(self, _next): + """Set next node of current node in source code order.""" + self._next = _next + def has_same_ast(self, node: Union['Node', ast.AST]) -> bool: """ Check if other node holds same ast node with self. @@ -460,7 +505,7 @@ class Node: A bool. """ if isinstance(node, Node): - return self.has_same_ast(node._ast_node) + return self.has_same_ast(node.get_ast()) if isinstance(node, ast.AST): return id(self._ast_node) == id(node) return False @@ -570,7 +615,8 @@ class Node: keyword_map_index[keyword_ast.arg] = index for keyword_index in range(self._kwargs_num): key = self._normalized_args_keys[keyword_index + self._args_num] - AstModifier.update_arg_value(self._normalized_args.get(key), keywords_ast[keyword_map_index[key]].value) + AstModifier.update_arg_value(self._normalized_args.get(key), + keywords_ast[keyword_map_index.get(key)].value) def _sync_call_method_args_to_ast(self): """Sync args of ast.Cell of ast.Assign from self._normalized_args when NodeType is CallMethod.""" @@ -646,9 +692,9 @@ class Node: origin_prev: Optional[Node] = self._prev origin_next: Optional[Node] = self._next if origin_prev is not None: - origin_prev._next = origin_next + origin_prev.set_next(origin_next) if origin_next is not None: - origin_next._prev = origin_prev + origin_next.set_prev(origin_prev) self._prev = None self._next = None @@ -662,9 +708,9 @@ class Node: node.isolate() origin_prev: Optional[Node] = self._prev if origin_prev is not None: - origin_prev._next = node - node._prev = origin_prev - node._next = self + origin_prev.set_next(node) + node.set_prev(origin_prev) + node.set_next(self) self._prev = node def insert_after(self, node: 'Node'): @@ -677,10 +723,10 @@ class Node: node.isolate() origin_next: Optional[Node] = self._next self._next = node - node._prev = self - node._next = origin_next + node.set_prev(self) + node.set_next(origin_next) if origin_next is not None: - origin_next._prev = node + origin_next.set_prev(node) def get_inputs(self) -> ['Node']: """ @@ -850,12 +896,12 @@ class Node: if arg_idx >= self._args_num or arg_idx < 0: raise RuntimeError("arg_idx out of range: ", arg_idx) if out_idx is None: - if len(node._targets) != 1: + if len(node.get_targets()) != 1: raise RuntimeError("node should has one output when out_idx is not provided") out_idx = 0 - if out_idx >= len(node._targets): + if out_idx >= len(node.get_targets()): raise RuntimeError("out_idx out of range: ", out_idx) - new_arg = node._targets[out_idx] + new_arg = node.get_targets()[out_idx] self._normalized_args[self._normalized_args_keys[arg_idx]] = new_arg self._sync_arg() diff --git a/mindspore/python/mindspore/rewrite/node_visitor.py b/mindspore/python/mindspore/rewrite/node_visitor.py new file mode 100644 index 00000000000..0caa11a361c --- /dev/null +++ b/mindspore/python/mindspore/rewrite/node_visitor.py @@ -0,0 +1,44 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Visit nods of SymbolTree.""" + + +class NodeVisitor: + """Iterator class to access SymbolTree nodes""" + def __init__(self, stree): + self._stree = stree + self._nodes = [] + self._index = 0 + + def __iter__(self): + self._nodes = list(self._stree.get_nodes_dict().values()) + self._index = 0 + return self + + def __next__(self): + if self._index < len(self._nodes): + node = self._nodes[self._index] + self._index += 1 + return node + + raise StopIteration + + def append_node(self, node): + """append new node to iterator""" + self._nodes.append(node) + + def remove_node(self, node): + """remove node of iterator""" + self._nodes.remove(node) diff --git a/mindspore/python/mindspore/rewrite/parsers/assign_parser.py b/mindspore/python/mindspore/rewrite/parsers/assign_parser.py index c6606ab6171..afb33c1eee4 100644 --- a/mindspore/python/mindspore/rewrite/parsers/assign_parser.py +++ b/mindspore/python/mindspore/rewrite/parsers/assign_parser.py @@ -186,7 +186,6 @@ class AssignParser(Parser): return results def _is_subtree_cell(self, cell: Cell) -> bool: - assert isinstance(cell, Cell) return not type(cell).__name__ in self._cell_namespce @staticmethod @@ -353,7 +352,6 @@ class AssignParser(Parser): # self._subnet = SubNet1(global_vars.get("subnet_args")) # so a change in sub-network should also be identified as a change in main-network. # so main-network should observe sub-network - new_stree.reg_observer(stree) replacer = AstReplacer(new_stree.get_class_ast()) replacer.replace_all(new_stree.get_ori_cls_name(), new_stree.get_opt_cls_name()) return TreeNode(new_stree, father_ast_node, targets, func, call_args, call_kwargs, func_name, diff --git a/mindspore/python/mindspore/rewrite/parsers/class_def_parser.py b/mindspore/python/mindspore/rewrite/parsers/class_def_parser.py index 15d4737d284..e76601d4402 100644 --- a/mindspore/python/mindspore/rewrite/parsers/class_def_parser.py +++ b/mindspore/python/mindspore/rewrite/parsers/class_def_parser.py @@ -114,7 +114,6 @@ class ClassDefParser(Parser): def _is_subtree_field(self, ori_net, field) -> bool: op = getattr(ori_net, field) - assert op is not None return not type(op).__name__ in self._cell_namespace def _process_init_func_ast(self, stree: SymbolTree, init_ast: ast.FunctionDef): diff --git a/mindspore/python/mindspore/rewrite/symbol_tree.py b/mindspore/python/mindspore/rewrite/symbol_tree.py index 486210b3838..f63679c9a62 100644 --- a/mindspore/python/mindspore/rewrite/symbol_tree.py +++ b/mindspore/python/mindspore/rewrite/symbol_tree.py @@ -32,6 +32,7 @@ from .namer import TargetNamer, NodeNamer, ClassNamer from .common.observer import Observer from .common.observable import Observable from .common.event import Event +from .node_visitor import NodeVisitor class Position: @@ -113,6 +114,7 @@ class SymbolTree(Observer, Observable): self._return: Optional[Node] = None self._modified = False + self._node_visitor = None def finish_build(self): self.add_event(Event.TopologicalChangeEvent) @@ -258,25 +260,27 @@ class SymbolTree(Observer, Observable): raise RuntimeError("Key of global_vars duplicated:", key) self._global_vars[key] = value - def nodes(self, unfold_subtree=False): + def get_nodes_dict(self): + """Get dict of nodes""" + return self._nodes + + def nodes(self): """ Getter of nodes if current SymbolTree. - Args: - unfold_subtree (bool): Need to iterate into sub-symbol-tree recursively. - Returns: A list of instance of Nodes. """ - if unfold_subtree: - nodes = [] - for _, v in self._nodes.items(): - if isinstance(v, TreeNode): - nodes.extend(v.symbol_tree.nodes(unfold_subtree)) - else: - nodes.append(v) - return nodes - return self._nodes.values() + if self._node_visitor is None: + self._node_visitor = NodeVisitor(self) + it = iter(self._node_visitor) + + while True: + try: + n = next(it) + yield n + except StopIteration: + return None def get_node(self, node_name: str) -> Optional[Node]: """ @@ -431,19 +435,44 @@ class SymbolTree(Observer, Observable): # _unique_targets must called after _update_args_for_unique and _update_kwargs_for_unique self._unique_targets(node) self._insert_node(position, node) + if isinstance(node, TreeNode): + node.symbol_tree.reg_observer(self) + if self._node_visitor: + self._node_visitor.append_node(node) # update init-function-ast and construct-function-ast if insert_to_ast: node.set_func(ScopedValue.create_naming_value(node_name, "self")) - node_ast = node.get_ast() - if not isinstance(node_ast, ast.Assign): - raise RuntimeError("Only support insert cell op now") - AstModifier.insert_assign_to_function(self._init_func_ast, - targets=[ScopedValue(ValueType.NamingValue, "self", node_name)], - expr=ScopedValue(ValueType.NamingValue, "global_vars", "get"), - args=[ScopedValue(ValueType.StringValue, "", node_name)]) - AstModifier.insert_assign_ast_to_function(self._root_ast, node_ast, - None if position is None else position.node.get_ast(), - position.before_node) + if isinstance(node, TreeNode): + global_vars_key = node.get_name() + "_args" + self.add_global_vars(global_vars_key, node.symbol_tree.get_global_vars()) + args_call = AstModifier.create_call(ScopedValue.create_naming_value("get", "global_vars"), + [ScopedValue.create_variable_value(global_vars_key)]) + value = ast.Call(func=ast.Name(node.symbol_tree.get_opt_cls_name(), ast.Store(), lineno=0, + col_offset=0), args=[args_call], keywords=[], lineno=0, col_offset=0) + + ast_target = ast.Name("self." + node.get_name(), ast.Store(), lineno=0, col_offset=0) + assign = ast.Assign(targets=[ast_target], value=value, lineno=0, col_offset=0) + AstModifier.insert_assign_ast_to_function(self._init_func_ast, assign) + + assign_construct = AstModifier.create_call_assign(node.get_targets(), ScopedValue.create_naming_value + (node.get_name(), "self"), node.get_args(), {}) + AstModifier.insert_assign_ast_to_function(self._root_ast, assign_construct, + None if position is None else position.node.get_ast(), + position.before_node) + sub_stree: SymbolTree = node.symbol_tree + from .symbol_tree_builder import SymbolTreeBuilder + SymbolTreeBuilder.merge_module_of_subtree(self, sub_stree) + else: + node_ast = node.get_ast() + if not isinstance(node_ast, ast.Assign): + raise RuntimeError("Only support insert cell op now") + AstModifier.insert_assign_to_function(self._init_func_ast, + targets=[ScopedValue(ValueType.NamingValue, "self", node_name)], + expr=ScopedValue(ValueType.NamingValue, "global_vars", "get"), + args=[ScopedValue(ValueType.StringValue, "", node_name)]) + AstModifier.insert_assign_ast_to_function(self._root_ast, node_ast, + None if position is None else position.node.get_ast(), + position.before_node) self._global_vars[node_name] = node.get_instance() return node @@ -595,6 +624,8 @@ class SymbolTree(Observer, Observable): value.isolate() break self._topo_mgr.on_erase_node(node) + if self._node_visitor: + self._node_visitor.remove_node(node) return node def _insert_tree(self, position: Position, root: Node, insert_to_ast: bool = True) -> Node: diff --git a/mindspore/python/mindspore/rewrite/symbol_tree_builder.py b/mindspore/python/mindspore/rewrite/symbol_tree_builder.py index 33e78002c82..a29f249b9e9 100644 --- a/mindspore/python/mindspore/rewrite/symbol_tree_builder.py +++ b/mindspore/python/mindspore/rewrite/symbol_tree_builder.py @@ -44,6 +44,22 @@ class SymbolTreeBuilder: self._ast_root: ast.Module = ast.parse(network_str) self._root_tree: Optional[SymbolTree] = None + @staticmethod + def merge_module_of_subtree(main_tree: SymbolTree, sub_stree: SymbolTree): + """ + Merge ast.Module of sub-network into ast.Module of main-network. + + 1. Merge imports of ast.Module. + 2. Merge classes of ast.Module. + 3. Use merged ast.Module as module of main-network and sub-network. + """ + + father_mod = main_tree.get_module_ast() + sub_mod = sub_stree.get_module_ast() + SymbolTreeBuilder._merge_import_of_module(father_mod, sub_mod) + SymbolTreeBuilder._merge_class_of_module(father_mod, sub_mod) + sub_stree.set_module_ast(father_mod) + @staticmethod def _ast_transform(ast_root: ast.AST) -> ast.AST: """ @@ -80,7 +96,6 @@ class SymbolTreeBuilder: main_mod_finder = AstFinder(main_mod) imports_in_sub = copy(sub_mod_finder.find_all((ast.Import, ast.ImportFrom))) imports_in_main = copy(main_mod_finder.find_all((ast.Import, ast.ImportFrom))) - assert imports_in_main first_import = imports_in_main[0] for clazz in imports_in_sub: AstModifier.insert_sub_ast(main_mod, clazz, first_import, True) @@ -103,12 +118,11 @@ class SymbolTreeBuilder: main_mod_finder = AstFinder(main_mod) classes_in_sub = copy(sub_mod_finder.find_all(ast.ClassDef)) classes_in_main = copy(main_mod_finder.find_all(ast.ClassDef)) - assert classes_in_main first_class = classes_in_main[0] for clazz in classes_in_sub: AstModifier.insert_class_into_module(main_mod, clazz, first_class, True) - def _merge_module_of_subtree(self): + def _merge_module_of_subtrees(self): """ Merge ast.Module of all sub-networks into ast.Module of main-network. @@ -117,13 +131,9 @@ class SymbolTreeBuilder: 3. Use merged ast.Module as module of main-network and sub-network. """ - father_mod = self._root_tree.get_module_ast() for node in self._root_tree.nodes(): if isinstance(node, TreeNode): - sub_stree: SymbolTree = node.symbol_tree - SymbolTreeBuilder._merge_import_of_module(father_mod, sub_stree.get_module_ast()) - SymbolTreeBuilder._merge_class_of_module(father_mod, sub_stree.get_module_ast()) - sub_stree.set_module_ast(father_mod) + SymbolTreeBuilder.merge_module_of_subtree(self._root_tree, node.symbol_tree) def _reduce_redundant_import(self): """ @@ -140,7 +150,6 @@ class SymbolTreeBuilder: if isinstance(body, ast.Import): names = body.names for name in names: - assert isinstance(name, ast.alias) import_hash = hash((name.name, name.asname)) if import_hash in exist_import: continue @@ -150,7 +159,6 @@ class SymbolTreeBuilder: import_module = body.module names = body.names for name in names: - assert isinstance(name, ast.alias) import_hash = hash((import_module, name.name, name.asname)) if import_hash in exist_import_from: continue @@ -182,7 +190,7 @@ class SymbolTreeBuilder: self._root_tree: SymbolTree = SymbolTree(self._origin_net, self._ast_root) parser: Parser = ParserRegister.instance().get_parser(ast.Module) parser.process(self._root_tree, self._ast_root) - self._merge_module_of_subtree() + self._merge_module_of_subtrees() self._reduce_redundant_import() ast.fix_missing_locations(self._root_tree.get_module_ast()) self._root_tree.finish_build() diff --git a/tests/ut/python/rewrite/test_multiple_targets.py b/tests/ut/python/rewrite/test_multiple_targets.py index daaf250a537..3236d08fe03 100644 --- a/tests/ut/python/rewrite/test_multiple_targets.py +++ b/tests/ut/python/rewrite/test_multiple_targets.py @@ -15,6 +15,7 @@ from mindspore.nn import Cell, Conv2d from mindspore.rewrite import SymbolTree from mindspore.ops import operations as P +from .utils import get_node_by_index class SubNet(Cell): @@ -56,7 +57,8 @@ def test_multi_targets(): """ test_cls = NetMultiTargets() stree = SymbolTree.create(test_cls) - node = stree.nodes()[2] + node = get_node_by_index(stree, 2) + assert node is not None targets = node.get_targets() assert targets[0].value == 'c1' assert targets[1].value == 'c2' diff --git a/tests/ut/python/rewrite/test_net_simple.py b/tests/ut/python/rewrite/test_net_simple.py index 8a42f65a7f3..c6c6c2cc514 100644 --- a/tests/ut/python/rewrite/test_net_simple.py +++ b/tests/ut/python/rewrite/test_net_simple.py @@ -121,7 +121,6 @@ def erase_node_x_11(stree: SymbolTree): def transform(stree: SymbolTree): add_conv_before_flatten(stree) - add_my_cell_after_x_12(stree) erase_node_x_11(stree) diff --git a/tests/ut/python/rewrite/test_pattern_engine.py b/tests/ut/python/rewrite/test_pattern_engine.py index 73630d12208..8216e09302e 100644 --- a/tests/ut/python/rewrite/test_pattern_engine.py +++ b/tests/ut/python/rewrite/test_pattern_engine.py @@ -19,6 +19,7 @@ from mindspore.nn import Cell, Conv2d, BatchNorm2d, ReLU from mindspore.ops import Add, AddN from mindspore.rewrite import ScopedValue, Node, SymbolTree from mindspore.rewrite import PatternEngine, PatternNode, Replacement, VarNode +from .utils import get_symbol_tree_nodes_count def test_tree_pattern_match(): @@ -92,13 +93,13 @@ def test_one_to_one_pattern(): assert bn is not None assert relu1 is not None assert len(construct_ast.body) == 6 - assert len(stree.nodes()) == 7 + assert get_symbol_tree_nodes_count(stree) == 7 bn_replace = BnReplace() bn_replace.apply(stree) assert len(construct_ast.body) == 6 - assert len(stree.nodes()) == 7 + assert get_symbol_tree_nodes_count(stree) == 7 conv = stree.get_node("conv") bn = stree.get_node("bn") relu1 = stree.get_node("relu1") @@ -167,13 +168,13 @@ def test_one_to_multi_chain_pattern(): assert bn is not None assert relu1 is not None assert len(construct_ast.body) == 6 - assert len(stree.nodes()) == 7 + assert get_symbol_tree_nodes_count(stree) == 7 bn_replace = BnReplace() bn_replace.apply(stree) assert len(construct_ast.body) == 7 - assert len(stree.nodes()) == 8 + assert get_symbol_tree_nodes_count(stree) == 8 conv = stree.get_node("conv") bn = stree.get_node("bn") relu1 = stree.get_node("relu1") @@ -296,13 +297,13 @@ def test_tree_pattern(): assert relu2 is not None construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast") assert len(construct_ast.body) == 8 - assert len(stree.nodes()) == 9 + assert get_symbol_tree_nodes_count(stree) == 9 add_relu_pattern = AddReluPattern() add_relu_pattern.apply(stree) assert len(construct_ast.body) == 10 - assert len(stree.nodes()) == 11 + assert get_symbol_tree_nodes_count(stree) == 11 conv1 = stree.get_node("conv1") conv2 = stree.get_node("conv2") add = stree.get_node("add") @@ -481,13 +482,13 @@ def test_multi_input_to_multi_pattern_tree_pattern(): assert relu is not None construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast") assert len(construct_ast.body) == 6 - assert len(stree.nodes()) == 9 + assert get_symbol_tree_nodes_count(stree) == 9 multi_input_pattern = MultiInputPattern() multi_input_pattern.apply(stree) assert len(construct_ast.body) == 4 - assert len(stree.nodes()) == 7 + assert get_symbol_tree_nodes_count(stree) == 7 conv1 = stree.get_node("conv1") conv2 = stree.get_node("conv2") add1 = stree.get_node("add1") @@ -598,13 +599,13 @@ def test_one_input_to_multi_pattern_tree_pattern(): assert relu is not None construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast") assert len(construct_ast.body) == 6 - assert len(stree.nodes()) == 7 + assert get_symbol_tree_nodes_count(stree) == 7 multi_input_pattern = MultiInputPattern() multi_input_pattern.apply(stree) assert len(construct_ast.body) == 4 - assert len(stree.nodes()) == 5 + assert get_symbol_tree_nodes_count(stree) == 5 conv1 = stree.get_node("conv1") conv2 = stree.get_node("conv2") add1 = stree.get_node("add1") diff --git a/tests/ut/python/rewrite/test_subtree_net.py b/tests/ut/python/rewrite/test_subtree_net.py index 98318afa050..85d7c49871a 100644 --- a/tests/ut/python/rewrite/test_subtree_net.py +++ b/tests/ut/python/rewrite/test_subtree_net.py @@ -100,10 +100,52 @@ def erase_relu_in_conv2(stree: SymbolTree): break +def inset_subtree(stree: SymbolTree): + for node in stree.nodes(): + if node.get_name() == "conv2": + position = stree.before(node) + subtree = SubNet() + new_node = Node.create_call_cell(subtree, targets=[ScopedValue.create_naming_value('x')], name='conv', + args=[ScopedValue.create_naming_value('x')], kwargs={}) + stree.insert(position, new_node) + break + + +def inset_subtree2(stree: SymbolTree): + for node in stree.nodes(): + if node.get_name() == "conv2": + position = stree.before(node) + subtree = SubNet() + new_node = Node.create_call_cell(subtree, targets=[ScopedValue.create_naming_value('x')], name='conv11', + args=[ScopedValue.create_naming_value('x')], kwargs={}) + stree.insert(position, new_node) + break + + +def add_relu_in_conv11(stree: SymbolTree): + for node in stree.nodes(): + if node.get_node_type() != NodeType.Tree: + continue + if node.get_name() == "conv11": + _stree: SymbolTree = TreeNodeHelper.get_sub_tree(node) + for inner_node in _stree.nodes(): + if inner_node.get_node_type() != NodeType.Output: + continue + position = _stree.before(inner_node) + new_relu = nn.ReLU() + new_relu_node = Node.create_call_cell(new_relu, targets=['x'], name='relu1', + args=[ScopedValue.create_naming_value('x')]) + _stree.insert(position, new_relu_node) + _stree.set_output(0, new_relu_node.get_targets()[0].value) + break + break + + def transform(stree: SymbolTree): add_relu_in_conv1(stree) replace_bn_in_conv2(stree) erase_relu_in_conv2(stree) + inset_subtree(stree) def test_subtree_net(): @@ -115,7 +157,20 @@ def test_subtree_net(): net = MainNet() stree = SymbolTree.create(net) + print(stree.get_code()) transform(stree) + for node in stree.nodes(): + print("after transform node name: ", node.get_name(), "; node type: ", node.get_node_type()) + if node.get_node_type() != NodeType.Tree: + continue + if node.get_name() == "conv": + modify_stree: SymbolTree = TreeNodeHelper.get_sub_tree(node) + for inner_node in modify_stree.nodes(): + print("inserted subtree node: ", inner_node.get_name()) + + inset_subtree2(stree) + add_relu_in_conv11(stree) + print(stree.get_code()) print(stree.get_handler().get_global_vars().keys()) net_opt = stree.get_network() diff --git a/tests/ut/python/rewrite/test_symbol_tree.py b/tests/ut/python/rewrite/test_symbol_tree.py index 6599dd78804..a4a3efe7f5f 100644 --- a/tests/ut/python/rewrite/test_symbol_tree.py +++ b/tests/ut/python/rewrite/test_symbol_tree.py @@ -21,6 +21,7 @@ from mindspore.rewrite import ScopedValue, ValueType, NodeType from mindspore.rewrite import Node as NodeApi from mindspore.rewrite.symbol_tree import SymbolTree from mindspore.rewrite.node import Node +from .utils import get_symbol_tree_nodes_count class Network(Cell): @@ -107,7 +108,7 @@ def test_insert_node(): consumers = getattr(getattr(stree, "_topo_mgr"), "_target_consumer") providers_len = len(providers) consumers_len = len(consumers) - assert len(stree.nodes()) == 7 + assert get_symbol_tree_nodes_count(stree) == 7 assert len(construct_ast.body) == 6 assert len(relu1.get_targets()) == 1 assert len(relu2.get_normalized_args().values()) == 1 @@ -120,7 +121,7 @@ def test_insert_node(): position = stree.before(relu2) node = stree.insert_node(position, node) # check nodes size - assert len(stree.nodes()) == 8 + assert get_symbol_tree_nodes_count(stree) == 8 # check args assert len(relu2.get_normalized_args().values()) == 1 assert relu1.get_targets()[0] == list(relu2.get_normalized_args().values())[0] @@ -158,7 +159,7 @@ def test_set_node_arg(): Expectation: Success. """ stree, bn, relu1, relu2 = create_symbol_tree() - assert len(stree.nodes()) == 7 + assert get_symbol_tree_nodes_count(stree) == 7 assert len(bn.get_targets()) == 1 bn_output = bn.get_targets()[0] # check bn topological order @@ -210,7 +211,7 @@ def test_set_node_arg_by_node(): Expectation: Success. """ stree, bn, relu1, relu2 = create_symbol_tree() - assert len(stree.nodes()) == 7 + assert get_symbol_tree_nodes_count(stree) == 7 assert len(bn.get_targets()) == 1 bn_output = bn.get_targets()[0] # check bn topological order @@ -265,13 +266,13 @@ def test_erase_succeed(): construct_ast: ast.FunctionDef = getattr(stree, "_root_ast") providers = getattr(getattr(stree, "_topo_mgr"), "_target_provider") providers_len = len(providers) - assert len(stree.nodes()) == 7 + assert get_symbol_tree_nodes_count(stree) == 7 assert len(construct_ast.body) == 6 stree.set_node_arg_by_node(relu2, 0, bn) stree.erase_node(relu1) - assert len(stree.nodes()) == 6 + assert get_symbol_tree_nodes_count(stree) == 6 assert len(providers) == providers_len - 1 assert len(construct_ast.body) == 5 @@ -300,13 +301,13 @@ def test_replace_one_to_one(): stree, bn, relu1, relu2 = create_symbol_tree() construct_ast: ast.FunctionDef = getattr(stree, "_root_ast") assert len(construct_ast.body) == 6 - assert len(stree.nodes()) == 7 + assert get_symbol_tree_nodes_count(stree) == 7 new_conv = Conv2d(16, 16, 5) new_conv_node = NodeApi.create_call_cell(new_conv, [ScopedValue.create_naming_value("new_conv")], bn.get_targets()).get_handler() new_conv_node = stree.replace(relu1, [new_conv_node]) - assert len(stree.nodes()) == 7 + assert get_symbol_tree_nodes_count(stree) == 7 # check ast assert len(construct_ast.body) == 6 node_ast: ast.Assign = construct_ast.body[2] @@ -341,7 +342,7 @@ def test_replace_one_to_multi(): stree, bn, relu1, relu2 = create_symbol_tree() construct_ast: ast.FunctionDef = getattr(stree, "_root_ast") assert len(construct_ast.body) == 6 - assert len(stree.nodes()) == 7 + assert get_symbol_tree_nodes_count(stree) == 7 new_conv_node = NodeApi.create_call_cell(Conv2d(16, 16, 5), [ScopedValue.create_naming_value("new_conv")], bn.get_targets()).get_handler() @@ -350,7 +351,7 @@ def test_replace_one_to_multi(): new_relu_node = stree.replace(relu1, [new_relu_node, new_conv_node]) new_conv_node = new_relu_node.get_inputs()[0] - assert len(stree.nodes()) == 8 + assert get_symbol_tree_nodes_count(stree) == 8 # check ast assert len(construct_ast.body) == 7 new_conv_ast: ast.Assign = construct_ast.body[2] diff --git a/tests/ut/python/rewrite/utils.py b/tests/ut/python/rewrite/utils.py new file mode 100644 index 00000000000..61297253af0 --- /dev/null +++ b/tests/ut/python/rewrite/utils.py @@ -0,0 +1,31 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +from mindspore.rewrite.api.symbol_tree import SymbolTree + + +def get_symbol_tree_nodes_count(stree: SymbolTree): + count = 0 + for _ in stree.nodes(): + count += 1 + return count + + +def get_node_by_index(stree: SymbolTree, index): + for i, node in enumerate(stree.nodes()): + if i == index: + return node + + return None