From 7938d845a06de1e403128e9e7deda0e0408a60b9 Mon Sep 17 00:00:00 2001 From: yuzhenhua Date: Thu, 3 Nov 2022 19:32:08 +0800 Subject: [PATCH] fix ckpt param name bug --- .jenkins/check/config/filter_pylint.txt | 1 + .../python/mindspore/nn/layer/container.py | 26 ++ .../python/mindspore/rewrite/api/node_type.py | 1 + .../mindspore/rewrite/api/pattern_engine.py | 31 ++ .../rewrite/ast_helpers/ast_modifier.py | 6 +- .../python/mindspore/rewrite/namespace.py | 2 - mindspore/python/mindspore/rewrite/node.py | 86 +++- .../rewrite/parsers/assign_parser.py | 47 +- .../rewrite/parsers/class_def_parser.py | 20 +- .../mindspore/rewrite/parsers/for_parser.py | 7 +- .../rewrite/parsers/function_def_parser.py | 2 +- .../python/mindspore/rewrite/symbol_tree.py | 80 ++-- mindspore/python/mindspore/train/amp.py | 12 +- tests/ut/python/rewrite/test_cellcontainer.py | 402 ++++++++++++++++++ tests/ut/python/rewrite/test_for.py | 45 +- tests/ut/python/rewrite/test_net_simple.py | 1 - 16 files changed, 694 insertions(+), 75 deletions(-) create mode 100644 tests/ut/python/rewrite/test_cellcontainer.py diff --git a/.jenkins/check/config/filter_pylint.txt b/.jenkins/check/config/filter_pylint.txt index 1146bd04adb..b22e5dba2cc 100644 --- a/.jenkins/check/config/filter_pylint.txt +++ b/.jenkins/check/config/filter_pylint.txt @@ -169,6 +169,7 @@ "mindspore/tests/ut/python/mindir/test_mindir_export.py" "no-else-return" "mindspore/tests/" "c-extension-no-member" "mindspore/tests/st/parameter/test_parameter_celllist.py" "protected-access" +"mindspore/tests/ut/python/rewrite/test_cellcontainer.py" "protected-access" #MindSpore Lite "mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental/HPC-generator/generator.py" "redefined-builtin" diff --git a/mindspore/python/mindspore/nn/layer/container.py b/mindspore/python/mindspore/nn/layer/container.py index bda548191d3..bc8c96aedc2 100644 --- a/mindspore/python/mindspore/nn/layer/container.py +++ b/mindspore/python/mindspore/nn/layer/container.py @@ -279,6 +279,32 @@ class SequentialCell(Cell): input_data = cell(input_data) return input_data + def _insert(self, index, cell): + """ + Inserts a given Cell before a given index in the list. + + Args: + index(int): The Insert index in the CellList. + cell(Cell): The Cell to be inserted. + """ + cls_name = self.__class__.__name__ + idx = _valid_index(len(self), index, cls_name) + _valid_cell(cell, cls_name) + length = len(self) + prefix, key_index = _get_prefix_and_index(self._cells) + while length > idx: + if self._auto_prefix: + tmp_cell = self._cells[str(length-1)] + for _, param in tmp_cell.parameters_and_names(): + param.name = prefix + str(length) + "." + ".".join(param.name.split(".")[key_index+1:]) + self._cells[str(length)] = self._cells[str(length - 1)] + length -= 1 + self._cells[str(idx)] = cell + if self._auto_prefix: + cell.update_parameters_name(prefix + str(idx) + ".") + self.cell_list = list(self._cells.values()) + self._is_dynamic_name.insert(index, True) + class CellList(_CellListBase, Cell): """ diff --git a/mindspore/python/mindspore/rewrite/api/node_type.py b/mindspore/python/mindspore/rewrite/api/node_type.py index 65f7e4e64a9..0c81a1e6420 100644 --- a/mindspore/python/mindspore/rewrite/api/node_type.py +++ b/mindspore/python/mindspore/rewrite/api/node_type.py @@ -43,3 +43,4 @@ class NodeType(Enum): Input = 7 Output = 8 Tree = 9 + CellContainer = 10 diff --git a/mindspore/python/mindspore/rewrite/api/pattern_engine.py b/mindspore/python/mindspore/rewrite/api/pattern_engine.py index 5ae95035337..87af14e0e3c 100644 --- a/mindspore/python/mindspore/rewrite/api/pattern_engine.py +++ b/mindspore/python/mindspore/rewrite/api/pattern_engine.py @@ -308,6 +308,16 @@ class PatternEngine: queue.extend(inputs_dict.get(cur_node.get_name())) return new_root + @staticmethod + def _multi_replace_cellcontainer(stree, cellcontainer, node, matched_dict, new_nodes): + """Replace node in CellContainer.""" + to_erase_list = list(matched_dict.values()) + stree.replace(Node(node), new_nodes) + for n in reversed(to_erase_list): + if n.get_handler() is node: + continue + stree.erase_node(n) + def apply(self, stree: SymbolTree) -> bool: """ Apply current pattern to a `SymbolTree`. @@ -359,6 +369,9 @@ class PatternEngine: visited.append(cur_node) queue.extend(cur_node.get_users()) continue + if cur_node.get_node_type() == NodeType.CellContainer: + self._process_cellcontainer(stree, cur_node.get_handler()) + continue visited.append(cur_node) matched, matched_dict = self._match(self._pattern, cur_node) # not matched @@ -460,3 +473,21 @@ class PatternEngine: logger.debug("Check match failed, pattern leaked") return False return True + + def _process_cellcontainer(self, stree, cellcontainer): + """Process CellContainer node.""" + for node in cellcontainer.nodes(): + if node.get_node_type() == NodeType.Tree: + subtree = node.symbol_tree + self.apply(SymbolTree(subtree)) + continue + else: + matched, matched_dict = self._match(self._pattern, Node(node)) + if not matched: + continue + new_nodes = [] + if self._replacement is not None: + new_nodes = self._replacement(self._pattern, self._is_chain, matched_dict) + if not new_nodes: # if replacement is empty, do nothing + continue + PatternEngine._multi_replace_cellcontainer(stree, cellcontainer, node, matched_dict, new_nodes) diff --git a/mindspore/python/mindspore/rewrite/ast_helpers/ast_modifier.py b/mindspore/python/mindspore/rewrite/ast_helpers/ast_modifier.py index 02894bb2031..2cec9ab6316 100644 --- a/mindspore/python/mindspore/rewrite/ast_helpers/ast_modifier.py +++ b/mindspore/python/mindspore/rewrite/ast_helpers/ast_modifier.py @@ -241,8 +241,10 @@ class AstModifier(ast.NodeTransformer): An instance of ast.Assign which has been appended to 'init_func'. """ return AstModifier.insert_assign_to_function(init_func, targets=targets, - args=[ScopedValue.create_variable_value(field)], - expr=ScopedValue(ValueType.NamingValue, "global_vars", "get")) + expr=ScopedValue(ValueType.NamingValue, "", "setattr"), + args=[ScopedValue(ValueType.NamingValue, "obj"), + ScopedValue.create_variable_value(field)]) + @staticmethod def create_call_assign(targets: [ScopedValue], expr: ScopedValue, args: [ScopedValue], diff --git a/mindspore/python/mindspore/rewrite/namespace.py b/mindspore/python/mindspore/rewrite/namespace.py index 33299d650ef..8cd81df4f1b 100644 --- a/mindspore/python/mindspore/rewrite/namespace.py +++ b/mindspore/python/mindspore/rewrite/namespace.py @@ -24,8 +24,6 @@ _ms_functional_ns = CellNamespace('mindspore.ops.functional') def is_subtree(cls_name): """Determine whether 'cls_name' is a subtree.""" - if cls_name == "SequentialCell": - return True if cls_name == "QuantizeWrapperCell": return False if cls_name in _ms_common_ns or cls_name in _ms_nn_ns or cls_name in _ms_ops_ns: diff --git a/mindspore/python/mindspore/rewrite/node.py b/mindspore/python/mindspore/rewrite/node.py index ac84b1f962f..32f249939e5 100644 --- a/mindspore/python/mindspore/rewrite/node.py +++ b/mindspore/python/mindspore/rewrite/node.py @@ -624,7 +624,7 @@ class Node: """ self._targets = targets if self._node_type in (NodeType.CallCell, NodeType.CallMethod, NodeType.CallPrimitive, - NodeType.Tree, NodeType.CallFunction): + NodeType.Tree, NodeType.CallFunction, NodeType.CellContainer): self._sync_assign_targets_to_ast() def get_func(self) -> ScopedValue: @@ -1135,7 +1135,7 @@ class Node: def _sync_arg(self): """Sync _normalized_args to corresponding ast node when updated.""" - if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree): + if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree, NodeType.CellContainer): self._sync_call_cell_args_to_ast() elif self._node_type == NodeType.Output: self._sync_return_node_to_ast() @@ -1198,3 +1198,85 @@ class TreeNode(Node): if ast_node is None: ast_node = AstModifier.create_call_assign(new_targets, func, non_custom_args, non_custom_kwargs) return cls(tree, ast_node, new_targets, func, args, kwargs, name, instance) + + +class CellContainer(Node): + """ Container for saving cell-objects node. """ + class _Visitor(): + """ A iterator of CellContainer nodes. """ + def __init__(self, cellcontainer): + self._cellcontainer = cellcontainer + + def __len__(self): + """ Get the number of nodes. """ + return self._cellcontainer.node_count + + def __iter__(self): + """Create an iterator over the CellContainer.""" + count = len(self._cellcontainer.node_list) + i = 0 + while i < count: + curr = self._cellcontainer.node_list[i] + if curr.valid: + yield curr + i += 1 + + def __init__(self, ast_node: ast.AST, targets: [ScopedValue], func: ScopedValue, + args: [ScopedValue], kwargs: {str: ScopedValue}, name: str, instance): + """Constructor of CellContainer. + + Args: + ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast. + targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class. + func ([ScopedValue, optional]): An instance of ScopedValue. See detail in docstring of Node class. + args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class. + kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class. + name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree. + Name of node also used as field name in network class. + instance: Object in network corresponding to this node. + """ + if isinstance(func, str): + func = ScopedValue.create_naming_value(func) + super().__init__(NodeType.CellContainer, ast_node, targets, func, args, kwargs, name, instance) + self._node_list = list() + self._node_count = 0 + + @property + def node_count(self): + """Number of nodes.""" + return self._node_count + + @node_count.setter + def node_count(self, count): + """Set number of nodes.""" + self._node_count = count + + @property + def node_list(self): + """ Get node list. """ + return self._node_list + + def append(self, node): + """ Append new node to node list. """ + self._node_list.append(node) + self.get_instance().append(node.get_instance()) + self.node_count += 1 + + def erase(self, node): + """Erase node form container.""" + index = self.node_list.index(node) + setattr(node, "valid", False) + self.node_count -= 1 + index = self.get_instance().cell_list.index(node.get_instance()) + del self.get_instance()[index] + + def insert(self, index, node): + """Insert node into container""" + self.node_list.insert(index, node) + setattr(node, "valid", True) + self.get_instance()._insert(index, node.get_instance()) + self.node_count += 1 + + def nodes(self): + """ Return a iterator of node.""" + return self._Visitor(self) diff --git a/mindspore/python/mindspore/rewrite/parsers/assign_parser.py b/mindspore/python/mindspore/rewrite/parsers/assign_parser.py index 4d4074c1509..77375cc8bc7 100644 --- a/mindspore/python/mindspore/rewrite/parsers/assign_parser.py +++ b/mindspore/python/mindspore/rewrite/parsers/assign_parser.py @@ -19,13 +19,13 @@ import astunparse from mindspore import log as logger from mindspore._extends.parse.namespace import CellNamespace -from mindspore.nn import Cell +from mindspore.nn import Cell, SequentialCell from mindspore.ops import operations as P from mindspore.ops import Primitive from mindspore.rewrite.parser_register import ParserRegister from mindspore.rewrite.namespace import is_subtree, is_functional, get_functional from mindspore.rewrite.symbol_tree import SymbolTree -from mindspore.rewrite.node import Node, TreeNode +from mindspore.rewrite.node import Node, TreeNode, CellContainer from mindspore.rewrite.parser import Parser from mindspore.rewrite.parser_register import reg_parser from mindspore.rewrite.api.scoped_value import ScopedValue, ValueType @@ -286,10 +286,10 @@ class AssignParser(Parser): if target.attr != func_name: continue changed = True - global_vars_key = "_".join([func_name, "args"]) - stree.add_global_vars(global_vars_key, sub_tree.get_global_vars()) - args_call = AstModifier.create_call(ScopedValue.create_naming_value("get", "global_vars"), - [ScopedValue.create_variable_value(global_vars_key)]) + setattr(stree.get_origin_network(), func_name, sub_tree.get_origin_network()) + args_call = AstModifier.create_call(ScopedValue(ValueType.NamingValue, "", "getattr"), + [ScopedValue(ValueType.NamingValue, "", "obj"), + ScopedValue(ValueType.StringValue, "", func_name)]) body.value = ast.Call(func=ast.Name(class_name, ast.Store()), args=[args_call], keywords=[]) break return changed @@ -308,6 +308,37 @@ class AssignParser(Parser): call_args = [AssignParser._create_scopedvalue(arg) for arg in father_ast_node.value.args] return Node.create_call_buildin_op(op, father_ast_node, targets, func, call_args, {}) + def _cell_container_process(self, ast_node, stree, targets, func, call_args, call_kwargs, op_name, container_obj): + """ parse cell container object.""" + cell_container = CellContainer(ast_node, targets, func, call_args, call_kwargs, op_name, container_obj) + for i, cell in enumerate(container_obj): + is_sub_tree = is_subtree(type(cell).__name__) + if is_sub_tree: + stb = SymbolTreeBuilder(cell) + new_stree = stb.build() + replacer = AstReplacer(new_stree.get_class_ast()) + replacer.replace_all(new_stree.get_ori_cls_name(), new_stree.get_opt_cls_name()) + tree = TreeNode.create_tree_node(new_stree, ast_node, targets, func, call_args, call_kwargs, + type(cell).__name__, cell) + setattr(tree, "container", cell_container) + setattr(tree, "valid", True) + tree.set_belong_symbol_tree(stree) + cell_container.node_list.append(tree) + cell_container.node_count += 1 + if i > 0: + tree.set_inputs([cell_container.node_list[i-1]]) + else: + node = Node.create_call_buildin_op(cell, ast_node, targets, func, call_args, call_kwargs, + type(cell).__name__) + setattr(node, "container", cell_container) + setattr(node, "valid", True) + node.set_belong_symbol_tree(stree) + cell_container.node_list.append(node) + cell_container.node_count += 1 + if i > 0: + node.set_inputs([cell_container.node_list[i-1]]) + return cell_container + def _convert_ast_call_to_node(self, ast_node: ast.Call, father_ast_node: ast.Assign, stree: SymbolTree) -> Node: """ Convert ast.Call to a symbol tree node. @@ -343,6 +374,10 @@ class AssignParser(Parser): return node raise RuntimeError(error_str(f"operator instance undefined.", child_node=ast_node.func, father_node=ast_node)) + if isinstance(op, SequentialCell): + node = self._cell_container_process(father_ast_node, stree, targets, func, call_args, call_kwargs, + func_name, op) + return node if isinstance(op, Primitive): return Node.create_call_buildin_op(op, father_ast_node, targets, func, call_args, call_kwargs, func_name) if isinstance(op, Cell): diff --git a/mindspore/python/mindspore/rewrite/parsers/class_def_parser.py b/mindspore/python/mindspore/rewrite/parsers/class_def_parser.py index f990bbc2fab..74c803c674d 100644 --- a/mindspore/python/mindspore/rewrite/parsers/class_def_parser.py +++ b/mindspore/python/mindspore/rewrite/parsers/class_def_parser.py @@ -21,8 +21,7 @@ from mindspore._extends.parse.namespace import CellNamespace from ..symbol_tree import SymbolTree from ..parser import Parser from ..parser_register import ParserRegister, reg_parser -from ..api.scoped_value import ScopedValue -from ..ast_helpers import AstReplacer, AstModifier +from ..ast_helpers import AstReplacer from ..common import error_str @@ -124,9 +123,6 @@ class ClassDefParser(Parser): super_index = ClassDefParser._find_super_expr_of_init_func(init_ast) ClassDefParser._modify_arguments_of_init_func(init_ast) self._replace_ori_field_of_init_func(stree, init_ast.body, super_index) - # re-find super_index for init_func changed in _replace_ori_field_of_init_func - super_index = ClassDefParser._find_super_expr_of_init_func(init_ast) - ClassDefParser._insert_handler_to_init_func(init_ast, super_index) @staticmethod def _find_super_expr_of_init_func(ast_init_fn: ast.FunctionDef) -> int: @@ -158,7 +154,7 @@ class ClassDefParser(Parser): def _modify_arguments_of_init_func(ast_init_fn: ast.FunctionDef): """Replace init function input parameters to self and global_vars.""" arg_self = ast.arg(arg="self", annotation="") - arg_global_vars = ast.arg(arg="global_vars", annotation="") + arg_global_vars = ast.arg(arg="obj", annotation="") ast_init_fn.args = ast.arguments(args=[arg_self, arg_global_vars], posonlyargs=[], kwonlyargs=[], kw_defaults=[], defaults=[], vararg=None, kwarg=None) ast.fix_missing_locations(ast_init_fn) @@ -235,22 +231,12 @@ class ClassDefParser(Parser): continue field_name = target.attr body.value = ast.Call(ast.Name('getattr', ast.Load()), - [ast.Attribute(ast.Name('self', ast.Load()), '_handler', ast.Load()), + [ast.Name('obj', ast.Load()), ast.Constant(value=field_name, kind=None)], []) for counter, index in enumerate(body_index_to_be_deleted): bodies.pop(index - counter) ClassDefParser._remove_empty_ast_in_init_func(bodies) - @staticmethod - def _insert_handler_to_init_func(ast_init_fn: ast.FunctionDef, super_index): - """Insert 'self._handler = global_vars.get('handler')' to init ast.FunctionDef.body""" - if super_index == -1: - super_index = 0 - AstModifier.insert_assign_to_function(ast_init_fn, [ScopedValue.create_naming_value("_handler", "self")], - ScopedValue.create_naming_value("get", "global_vars"), - [ScopedValue.create_variable_value("handler")], None, - ast_init_fn.body[super_index], False) - def process(self, stree: SymbolTree, node: ast.ClassDef): """ Parse init and construct in ast.ClassDef. diff --git a/mindspore/python/mindspore/rewrite/parsers/for_parser.py b/mindspore/python/mindspore/rewrite/parsers/for_parser.py index 653e2dcbe68..cc6e39d8e69 100644 --- a/mindspore/python/mindspore/rewrite/parsers/for_parser.py +++ b/mindspore/python/mindspore/rewrite/parsers/for_parser.py @@ -34,12 +34,13 @@ class ForParser(Parser): def modify_init_ast(stree, i, obj, iter_var_name): """Modify the ast node in init function.""" target = f"{iter_var_name.strip()}_{str(i)}" - stree.add_global_vars(target, obj) + setattr(stree.get_origin_network(), target, obj) stree.get_origin_network().insert_child_to_cell(target, obj) AstModifier.insert_assign_to_function(stree.get_init_func_ast(), targets=[ScopedValue(ValueType.NamingValue, "self", target)], - expr=ScopedValue(ValueType.NamingValue, "global_vars", "get"), - args=[ScopedValue(ValueType.StringValue, "", target)]) + expr=ScopedValue(ValueType.NamingValue, "", "getattr"), + args=[ScopedValue(ValueType.NamingValue, "", "obj"), + ScopedValue(ValueType.StringValue, "", target)]) @staticmethod def modify_construct_ast(stree, ast_node, old_name, new_name): diff --git a/mindspore/python/mindspore/rewrite/parsers/function_def_parser.py b/mindspore/python/mindspore/rewrite/parsers/function_def_parser.py index 8c085bc7f82..41748b426a5 100644 --- a/mindspore/python/mindspore/rewrite/parsers/function_def_parser.py +++ b/mindspore/python/mindspore/rewrite/parsers/function_def_parser.py @@ -44,7 +44,7 @@ class FunctionDefParser(Parser): else: parser.process(stree, body) - for body in node.body: + for body in node.body[::-1]: if isinstance(body, (ast.For, ast.If)): node.body.remove(body) if hasattr(node, "decorator_list"): diff --git a/mindspore/python/mindspore/rewrite/symbol_tree.py b/mindspore/python/mindspore/rewrite/symbol_tree.py index 4354e5097e2..3cb1d6b4897 100644 --- a/mindspore/python/mindspore/rewrite/symbol_tree.py +++ b/mindspore/python/mindspore/rewrite/symbol_tree.py @@ -160,7 +160,6 @@ class SymbolTree(Observer, Observable): self._topo_mgr = TopoManager() self._topo_mgr.reg_observer(self) - self._global_vars: {str, object} = {origin_network_key: origin_network} self._nodes: {str, Node} = {} # parameters of forward method self._inputs: [Node] = [] @@ -484,17 +483,6 @@ class SymbolTree(Observer, Observable): """ return self._origin_network - def get_global_vars(self): - """Get global variables.""" - return self._global_vars - - def add_global_vars(self, key: str, value): - """Add global variables.""" - if self._global_vars.get(key) is not None: - logger.info(f"The key '{key}' is duplicated") - return - self._global_vars[key] = value - def get_nodes_dict(self): """Get dict of nodes""" return self._nodes @@ -614,7 +602,6 @@ class SymbolTree(Observer, Observable): RuntimeError: If 'node_or_name' is not belong to this SymbolTree or any sub-SymbolTree of current SymbolTree. """ - node = self._get_real_node(node_or_name) if node is None: raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name) @@ -653,7 +640,12 @@ class SymbolTree(Observer, Observable): RuntimeError: If 'position' is not in current SymbolTree. RuntimeError: If corresponding ast node is not an ast.Assign when 'insert_to_ast' is True. """ - + if position is not None and hasattr(position.node, "container"): + cellcontainer = getattr(position.node, "container") + index = cellcontainer.node_list.index(position.node) + index = index if position.before_node else index + 1 + cellcontainer.insert(index, node) + return node # if position in current SymbolTree if position is not None and position.symbol_tree is not self: raise RuntimeError("Position is not in current SymbolTree:", position) @@ -683,10 +675,10 @@ class SymbolTree(Observer, Observable): if not isinstance(node_ast, ast.Assign): raise RuntimeError("Only support insert cell op now") if isinstance(node, TreeNode): - global_vars_key = node.get_name() + "_args" - self.add_global_vars(global_vars_key, node.symbol_tree.get_global_vars()) - args_call = AstModifier.create_call(ScopedValue.create_naming_value("get", "global_vars"), - [ScopedValue.create_variable_value(global_vars_key)]) + setattr(self._origin_network, node.get_name(), node.get_instance()) + args_call = AstModifier.create_call(ScopedValue(ValueType.NamingValue, "", "getattr"), + [ScopedValue(ValueType.NamingValue, "", "obj"), + ScopedValue(ValueType.StringValue, "", node.get_name())]) value = ast.Call(func=ast.Name(node.symbol_tree.get_opt_cls_name(), ast.Store(), lineno=0, col_offset=0), args=[args_call], keywords=[], lineno=0, col_offset=0) @@ -703,12 +695,13 @@ class SymbolTree(Observer, Observable): else: AstModifier.insert_assign_to_function(self._init_func_ast, targets=[ScopedValue(ValueType.NamingValue, "self", node_name)], - expr=ScopedValue(ValueType.NamingValue, "global_vars", "get"), - args=[ScopedValue(ValueType.StringValue, "", node_name)]) + expr=ScopedValue(ValueType.NamingValue, "", "getattr"), + args=[ScopedValue(ValueType.NamingValue, "", "obj"), + ScopedValue(ValueType.StringValue, "", node_name)]) AstModifier.insert_assign_ast_to_function(self._root_ast, node_ast, None if position is None else position.node.get_ast(), position.before_node) - self._global_vars[node_name] = node.get_instance() + setattr(self._origin_network, node_name, node.get_instance()) return node def append_node(self, node: Node, append_to_ast: bool = True) -> Node: @@ -851,6 +844,10 @@ class SymbolTree(Observer, Observable): node = self._get_real_node(node_or_name) if node is None: raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name) + if hasattr(node, "container"): + cellcontainer = getattr(node, "container") + cellcontainer.erase(node) + return node ret = AstModifier.erase_ast_from_function(self._root_ast, node.get_ast()) if not ret: raise RuntimeError("node not in function ast tree.") @@ -884,6 +881,9 @@ class SymbolTree(Observer, Observable): RuntimeError: If 'old_node' is not belong to current SymbolTree. """ + if hasattr(old_node, "container"): + self._replace_container_node(old_node, new_nodes) + return new_nodes[0] real_old_node = self._get_real_node(old_node) if real_old_node is None: raise RuntimeError("Old node is not belong to current SymbolTree:", old_node) @@ -1026,7 +1026,7 @@ class SymbolTree(Observer, Observable): A network object. """ cls = self._get_cls_through_file() - return cls(self._global_vars) + return cls(self._origin_network) def set_saved_file_name(self, file_name: str): """Sets the filename used to save the network.""" @@ -1070,6 +1070,14 @@ class SymbolTree(Observer, Observable): else: body.names.remove(alias) + def _replace_container_node(self, old_node, new_nodes): + cellcontainer = getattr(old_node, "container") + index = cellcontainer.node_list.index(old_node) + for n in reversed(new_nodes): + cellcontainer.insert(index, n) + index = cellcontainer.node_list.index(old_node) + cellcontainer.erase(old_node) + def _filter_out_to_delete_field(self, to_delete_field): """filter out used field from `to_delete_field`""" # filter _handler field @@ -1077,7 +1085,8 @@ class SymbolTree(Observer, Observable): to_delete_field.pop("_handler") # filter field used in node of construct for node in self._nodes.values(): - if node.get_node_type() in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree): + if node.get_node_type() in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree, + NodeType.CellContainer): func: ScopedValue = node.get_func() if func.scope == "self" and to_delete_field.get(func.value): to_delete_field.pop(func.value) @@ -1144,12 +1153,9 @@ class SymbolTree(Observer, Observable): self._module_ast.body.remove(body) def _get_real_node(self, node_or_name: Union[Node, str]) -> Optional[Node]: - if isinstance(node_or_name, Node): - result = self.get_node(node_or_name.get_name()) - return result if result is node_or_name else None if isinstance(node_or_name, str): return self.get_node(node_or_name) - return None + return node_or_name def _insert_tree(self, position: Position, root: Node, insert_to_ast: bool = True) -> Node: """ @@ -1298,7 +1304,7 @@ class SymbolTree(Observer, Observable): raise TypeError("value should be ScopedValue, got: ", type(value)) if value.type == ValueType.CustomObjValue: field = self._node_name_namer.get_name(f"var_{type(value.value).__name__}") - self._global_vars[field] = value.value + setattr(self._origin_network, field, value.value) init_targets = [ScopedValue.create_naming_value(field, "self")] AstModifier.append_global_vars_expr_to_init(self._init_func_ast, init_targets, field) result[arg] = init_targets[0] @@ -1316,7 +1322,8 @@ class SymbolTree(Observer, Observable): Returns: A class handle. """ - file_name = "new_network_{0}.py".format(int(time.time() * 10000)) + self._update_container() + file_name = "new_network_{0}.py".format(int(time.time() * 10000000)) with os.fdopen(os.open(file_name, os.O_WRONLY | os.O_CREAT, stat.S_IRWXU), 'wb') as f: source = self.get_code() f.write(source.encode('utf-8')) @@ -1333,3 +1340,18 @@ class SymbolTree(Observer, Observable): def _on_change(self, event: Event): self._modified = True self.changed(event) + + def _update_container(self): + """Update instance of node in container.""" + for node in self.nodes(): + index = 0 + if node.get_node_type() == NodeType.CellContainer: + for n in node.node_list: + if not n.valid: + continue + if n.get_node_type() == NodeType.Tree: + obj = n.symbol_tree.get_network() + node.get_instance()[index] = obj + else: + node.get_instance()[index] = n.get_instance() + index += 1 diff --git a/mindspore/python/mindspore/train/amp.py b/mindspore/python/mindspore/train/amp.py index 1a262bd7bff..22d62f857ee 100644 --- a/mindspore/python/mindspore/train/amp.py +++ b/mindspore/python/mindspore/train/amp.py @@ -99,7 +99,11 @@ def _insert_cast_operator(stree): if node.get_targets() is None: continue in_white_list = False - if node.get_node_type() != ms.rewrite.NodeType.Tree: + if node.get_node_type() == ms.rewrite.NodeType.CellContainer: + for n in node.get_handler().node_list: + if n.get_node_type() == ms.rewrite.NodeType.Tree: + _insert_cast_operator(ms.rewrite.TreeNodeHelper.get_sub_tree(ms.rewrite.Node(n))) + elif node.get_node_type() != ms.rewrite.NodeType.Tree: # insert cast before the primitive operators in white_list if node.get_instance_type() in AMP_WHITE_LIST_OPS: in_white_list = True @@ -164,7 +168,11 @@ def _remove_duplicated_cast(stree): for node in stree.nodes(): if node.get_targets() is None: continue - if node.get_node_type() != ms.rewrite.NodeType.Tree: + if node.get_node_type() == ms.rewrite.NodeType.CellContainer: + for n in node.get_handler().node_list: + if n.get_node_type() == ms.rewrite.NodeType.Tree: + _remove_duplicated_cast(ms.rewrite.TreeNodeHelper.get_sub_tree(ms.rewrite.Node(n))) + elif node.get_node_type() != ms.rewrite.NodeType.Tree: if node.get_instance_type() == P.Cast and _removed_cast_pair(node): # remove the following cast node first len_users = len(node.get_users()) diff --git a/tests/ut/python/rewrite/test_cellcontainer.py b/tests/ut/python/rewrite/test_cellcontainer.py new file mode 100644 index 00000000000..71d9c0c306b --- /dev/null +++ b/tests/ut/python/rewrite/test_cellcontainer.py @@ -0,0 +1,402 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test cell container.""" + +from mindspore import nn +from mindspore.ops import operations as P + +from mindspore.rewrite import SymbolTree, NodeType, TreeNodeHelper, Node, ScopedValue, PatternEngine, Replacement, \ + PatternNode + + +def _conv3x3(in_channel, out_channel, stride=1): + return nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, + padding=0, pad_mode='same', weight_init="ones") + + +def _conv1x1(in_channel, out_channel, stride=1): + return nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride, + padding=0, pad_mode='same', weight_init="ones") + + +def _bn(channel): + return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9, + gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1) + + +class ResidualBlock(nn.Cell): + expansion = 4 + + def __init__(self, + in_channel, + out_channel, + stride=1): + super(ResidualBlock, self).__init__() + self.stride = stride + channel = out_channel // self.expansion + self.conv1 = _conv1x1(in_channel, channel, stride=1) + self.bn1 = _bn(channel) + self.conv2 = _conv3x3(channel, channel, stride=stride) + self.bn2 = _bn(channel) + self.conv3 = _conv1x1(channel, out_channel, stride=1) + self.bn3 = _bn(out_channel) + self.relu = nn.ReLU() + self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride), _bn(out_channel)]) + + def construct(self, x): + identity = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + out = self.conv3(out) + out = self.bn3(out) + identity = self.down_sample_layer(identity) + out = out + identity + out = self.relu(out) + + return out + + +class ResNetSimple(nn.Cell): + def __init__(self): + super(ResNetSimple, self).__init__(auto_prefix=True) + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad', weight_init="ones") + self.bn1 = _bn(16) + self.relu = P.ReLU() + self.layer1 = self._make_layer(ResidualBlock, 3, in_channel=63, out_channel=256, stride=1) + self.layer1.append(self.conv1) + self.layer1.append(self.bn1) + self.reshape = P.Reshape() + self.out_channels = 10 + + def construct(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.layer1(x) + return x + + def _make_layer(self, block, layer_num, in_channel, out_channel, stride): + layers = [] + resnet_block = block(in_channel, out_channel, stride=stride) + layers.append(resnet_block) + for _ in range(1, layer_num): + resnet_block = ResidualBlock(out_channel, out_channel, stride=1) + layers.append(resnet_block) + return nn.SequentialCell(layers) + + +def test_cellcontainer_parse(): + """ + Feature: parse CellContainer Node. + Description: parse a network with SequentialCell object. + Expectation: Rewrite can parse a network with SquentialCell object successfully. + """ + net = ResNetSimple() + stree = SymbolTree.create(net) + for node in stree.nodes(): + if node.get_node_type() == NodeType.CellContainer: + assert len(node.get_handler().node_list) == 5 + for i, n in enumerate(node.get_handler().node_list): + if i < 3: + assert n.get_instance_type() is ResidualBlock + if i == 3: + assert n.get_instance_type() is nn.Conv2d + if i == 4: + assert n.get_instance_type() is nn.BatchNorm2d + + +def test_cellcontainer_insert(): + """ + Feature: modify CellContainer Node. + Description: using node in container to set insert location. + Expectation: raise ValueError. + """ + net = ResNetSimple() + stree = SymbolTree.create(net) + for node in stree.nodes(): + if node.get_node_type() == NodeType.CellContainer: + assert len(node.get_handler().nodes()) == 5 + for n in node.get_handler().nodes(): + if n.get_instance_type() is nn.Conv2d: + position = stree.before(Node(n)) + new_conv = nn.Conv2d(16, 16, 3) + new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv', + args=[ScopedValue.create_naming_value('self_max_po')]) + stree.insert(position, new_conv_node) + break + assert len(node.get_handler().nodes()) == 6 + assert node.get_handler().node_list[3].get_name() == "new_conv" + + +def test_cellcontainer_insert_ok(): + """ + Feature: modify CellContainer Node. + Description: Inserts a node within a tree node in CellContainer Node. + Expectation: Insertion succeeded. + """ + def _insert_conv(stree: SymbolTree): + for node in stree.nodes(): + if node.get_instance_type() == nn.BatchNorm2d: + position = stree.after(node) + new_conv = nn.Conv2d(16, 16, 3) + new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv', + args=[ScopedValue.create_naming_value('self_max_po')]) + stree.insert(position, new_conv_node) + break + net = ResNetSimple() + stree = SymbolTree.create(net) + for node in stree.nodes(): + if node.get_node_type() == NodeType.CellContainer: + for n in node.get_handler().node_list: + if n.get_node_type() == NodeType.Tree: + _insert_conv(TreeNodeHelper.get_sub_tree(Node(n))) + break + new_net = stree.get_network() + cell_container = getattr(new_net, "layer1") + assert hasattr(cell_container._cells["0"], "new_conv") + + +def test_cellcontainer_insert_to_subtree(): + """ + Feature: modify CellContainer Node. + Description: Inserts a node within a tree node in CellContainer Node. + Expectation: Insertion succeeded. + """ + def _insert_conv(stree: SymbolTree): + for node in stree.nodes(): + if node.get_instance_type() == nn.BatchNorm2d: + position = stree.after(node) + new_conv = nn.Conv2d(16, 16, 3) + new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv', + args=[ScopedValue.create_naming_value('self_max_po')]) + stree.insert(position, new_conv_node) + break + net = ResNetSimple() + stree = SymbolTree.create(net) + for node in stree.nodes(): + if node.get_node_type() == NodeType.CellContainer: + for n in node.get_handler().node_list: + if n.get_node_type() == NodeType.Tree: + _insert_conv(TreeNodeHelper.get_sub_tree(Node(n))) + break + new_net = stree.get_network() + cell_container = getattr(new_net, "layer1") + assert hasattr(cell_container._cells["0"], "new_conv") + + +def test_cellcontainer_del(): + """ + Feature: modify CellContainer Node. + Description: delete the CellContainer Node. + Expectation: success. + """ + net = ResNetSimple() + stree = SymbolTree.create(net) + original_nodes_size = len(stree.get_handler()._nodes) + for node in stree.nodes(): + if node.get_node_type() == NodeType.CellContainer and node.get_name() == "layer1": + users = node.get_users() + for user in users: + user.set_arg(0, "x") + stree.erase_node(node) + assert len(stree.get_handler()._nodes) == original_nodes_size - 1 + + +def test_cellcontainer_del_node(): + """ + Feature: modify CellContainer Node. + Description: delete the CellContainer Node. + Expectation: success. + """ + net = ResNetSimple() + stree = SymbolTree.create(net) + for node in stree.nodes(): + if node.get_node_type() == NodeType.CellContainer and node.get_name() == "layer1": + assert len(node.get_handler().nodes()) == 5 + for n in node.get_handler().nodes(): + users = node.get_users() + inputs = node.get_inputs() + for user in users: + user.set_arg_by_node(0, inputs[0]) + stree.erase_node(Node(n)) + break + assert len(node.get_handler().nodes()) == 4 + + +def test_cellcontainer_del_node_in_subtree(): + """ + Feature: modify CellContainer Node. + Description: delete a node within a tree node in CellContainer Node. + Expectation: success. + """ + def _del_node(sub_tree): + for _node in sub_tree.nodes(): + if _node.get_name() == "conv2": + users = Node(_node).get_users() + for user in users: + user.set_arg(0, "out") + sub_tree.erase_node(_node) + net = ResNetSimple() + stree = SymbolTree.create(net) + for node in stree.nodes(): + if node.get_node_type() == NodeType.CellContainer: + for i, n in enumerate(node.get_handler().node_list): + if n.get_node_type() == NodeType.Tree and i == 1: + sub_tree = n.symbol_tree + original_nodes_size = len(sub_tree._nodes) + _del_node(sub_tree) + assert len(sub_tree._nodes) == original_nodes_size - 1 + + new_net = stree.get_network() + cell_container = getattr(new_net, "layer1") + assert not hasattr(cell_container._cells["1"], "conv2") + + +def test_cellcontainer_replace(): + """ + Feature: modify CellContainer Node. + Description: replace CellContainer Node with another Node. + Expectation: success. + """ + def _replace_bn(stree: SymbolTree): + for node in stree.nodes(): + if node.get_node_type() == NodeType.CellContainer: + new_conv = nn.Conv2d(16, 16, 3) + new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv', + args=[ScopedValue.create_naming_value('x')]) + stree.replace(node, [new_conv_node]) + break + net = ResNetSimple() + stree = SymbolTree.create(net) + _replace_bn(stree) + new_net = stree.get_network() + assert not hasattr(new_net, "layer1") + assert hasattr(new_net, "new_conv") + + +def test_cellcontainer_replace_node(): + """ + Feature: modify CellContainer Node. + Description: replace the CellContainer Node. + Expectation: success. + """ + net = ResNetSimple() + stree = SymbolTree.create(net) + for node in stree.nodes(): + if node.get_node_type() == NodeType.CellContainer and node.get_name() == "layer1": + for n in node.get_handler().nodes(): + new_conv = nn.Conv2d(16, 16, 3) + new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv', + args=[ScopedValue.create_naming_value('x')]) + stree.replace(Node(n), [new_conv_node]) + break + assert node.get_handler().node_list[0].get_name() == "new_conv" + assert isinstance(node.get_handler().get_instance()._cells["0"], nn.Conv2d) + break + + +def test_cellcontainer_replace_in_subtree(): + """ + Feature: modify CellContainer Node. + Description: replace a node within a tree node in CellContainer Node. + Expectation: success. + """ + def _replace_bn(stree: SymbolTree): + for node in stree.nodes(): + if node.get_name() == "bn1": + new_conv = nn.Conv2d(16, 16, 3) + new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv', + args=[ScopedValue.create_naming_value('self_max_po')]) + stree.replace(node, [new_conv_node]) + break + net = ResNetSimple() + stree = SymbolTree.create(net) + for node in stree.nodes(): + if node.get_node_type() == NodeType.CellContainer: + for n in node.get_handler().node_list: + if n.get_node_type() == NodeType.Tree: + _replace_bn(TreeNodeHelper.get_sub_tree(Node(n))) + break + new_net = stree.get_network() + cell_container = getattr(new_net, "layer1") + assert not hasattr(cell_container._cells["0"], "bn1") + assert hasattr(cell_container._cells["0"], "new_conv") + + +def test_cellcontainer_pattern(): + """ + Feature: modify CellContainer Node. + Description: apply pattern matching and replacement on the network containing SequentialCell object. + Expectation: success. + """ + class ConvBnReplacement(Replacement): + def build(self, pattern: PatternNode, is_chain_pattern: bool, matched): + assert is_chain_pattern + assert pattern.type() == nn.BatchNorm2d + bn_node: Node = matched.get(pattern.name()) + assert bn_node is not None + assert len(pattern.get_inputs()) == 1 + add_pattern = pattern.get_inputs()[0] + assert add_pattern.type() == nn.Conv2d + add_node: Node = matched.get(add_pattern.name()) + assert add_node is not None + assert not add_pattern.get_inputs() + + new_maxpool1 = nn.MaxPool2d() + new_maxpool1_node = Node.create_call_cell(new_maxpool1, ['new_maxpool1'], add_node.get_args()) + new_relu1 = nn.ReLU() + new_relu1_node = Node.create_call_cell(new_relu1, ['new_relu_1'], + [ScopedValue.create_naming_value('new_maxpool1')]) + new_relu2 = nn.ReLU() + new_relu2_node = Node.create_call_cell(new_relu2, ['new_relu_2'], + [ScopedValue.create_naming_value('new_maxpool1')]) + new_maxpool2 = nn.BiDense(1, 1, 2) + new_maxpool2_node = Node.create_call_cell(new_maxpool2, ['new_maxpool2'], + [ScopedValue.create_naming_value('new_relu_1'), + ScopedValue.create_naming_value('new_relu_2')]) + return [new_maxpool1_node, new_relu1_node, new_relu2_node, new_maxpool2_node] + + + class ConvReluPattern(PatternEngine): + def __init__(self): + super().__init__([nn.Conv2d, nn.BatchNorm2d], ConvBnReplacement()) + + net = ResNetSimple() + stree = SymbolTree.create(net) + _pattern = ConvReluPattern() + _pattern.apply(stree) + new_net = stree.get_network() + cell_container = getattr(new_net, "layer1") + assert not hasattr(cell_container, "conv1") + assert not hasattr(cell_container, "bn1") + assert not hasattr(cell_container._cells["0"], "conv1") + assert not hasattr(cell_container._cells["1"], "conv1") + assert not hasattr(cell_container._cells["2"], "conv1") + assert hasattr(cell_container._cells["0"], "new_relu") + assert hasattr(cell_container._cells["0"], "new_maxpool1") + assert isinstance(getattr(getattr(cell_container._cells["0"], "down_sample_layer"), "0"), nn.MaxPool2d) + assert hasattr(cell_container._cells["1"], "new_relu") + assert hasattr(cell_container._cells["1"], "new_maxpool1") + assert isinstance(getattr(getattr(cell_container._cells["1"], "down_sample_layer"), "0"), nn.MaxPool2d) + assert hasattr(cell_container._cells["2"], "new_relu") + assert hasattr(cell_container._cells["2"], "new_maxpool1") + assert isinstance(getattr(getattr(cell_container._cells["2"], "down_sample_layer"), "0"), nn.MaxPool2d) + assert isinstance(getattr(cell_container, "3"), nn.MaxPool2d) + assert isinstance(getattr(cell_container, "4"), nn.ReLU) + assert isinstance(getattr(cell_container, "6"), nn.BiDense) diff --git a/tests/ut/python/rewrite/test_for.py b/tests/ut/python/rewrite/test_for.py index 0b6059b68b5..9c13bb223f0 100644 --- a/tests/ut/python/rewrite/test_for.py +++ b/tests/ut/python/rewrite/test_for.py @@ -2,6 +2,7 @@ from collections import OrderedDict from mindspore import nn +from mindspore.ops import operations as P from mindspore.rewrite import SymbolTree, PatternEngine, Replacement, PatternNode, Node, ScopedValue from mindspore.rewrite.api.tree_node_helper import TreeNodeHelper from mindspore.rewrite.api.node_type import NodeType @@ -108,6 +109,28 @@ class CellBlock(nn.Cell): return out +class SimpleNet(nn.Cell): + def __init__(self): + super().__init__() + self.mul = P.Mul() + self.dense = nn.Dense(in_channels=32, out_channels=32, weight_init="ones") + self.mean = P.ReduceMean(keep_dims=False) + self.split = P.Split(axis=1, output_num=3) + self.conv1 = nn.Conv2d(3, 64, 3, stride=2) + self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + self.block = CellBlock(3, 6) + + def construct(self, x): + y, _, _ = self.split(x) + y = self.mean(y, (2, 3)) + x = self.mul(x, 1) + x = self.block(x) + x = self.conv1(x) + x = self.max_pool2d(x) + x = self.dense(x) + return x, y + + class ForNetWithSubTree(nn.Cell): def __init__(self): super(ForNetWithSubTree, self).__init__() @@ -125,12 +148,14 @@ class ForNetWithSubTree(nn.Cell): resnet_block3 = CellBlock(16, 32) layers = [resnet_block1, resnet_block2, resnet_block3] self.layer2 = nn.SequentialCell(layers) + self.simple_net = SimpleNet() def construct(self, x): x = self.conv1(x) x = self.layer1(x) x = self.relu(x) x = self.layer2(x) + x = self.simple_net(x) return x @@ -144,7 +169,7 @@ def test_erase_subtree_node(): stree = SymbolTree.create(net) for node in stree.nodes(): - if node.get_name() == "layer1": + if node.get_name() == "simple_net": subtree = TreeNodeHelper.get_sub_tree(node) orig_node_num = len(subtree.get_handler()._nodes) for n in subtree.nodes(): @@ -169,11 +194,11 @@ def test_erase_subtree_node_01(): stree = SymbolTree.create(net) for node in stree.nodes(): - if node.get_name() == "layer2": + if node.get_name() == "simple_net": subtree = TreeNodeHelper.get_sub_tree(node) orig_node_num = len(subtree.get_handler()._nodes) for n in subtree.nodes(): - if n.get_name() == "cell_list_1": + if n.get_name() == "block": input_node = n.get_inputs()[0] output_nodes = n.get_users() for _nn in output_nodes: @@ -203,10 +228,10 @@ def test_erase_subtree_node_02(): net = ForNetWithSubTree() stree = SymbolTree.create(net) for node in stree.nodes(): - if node.get_name() == "layer2": + if node.get_name() == "simple_net": subtree = TreeNodeHelper.get_sub_tree(node) for n in subtree.nodes(): - if n.get_name() == "cell_list_1": + if n.get_name() == "block": subtree1 = TreeNodeHelper.get_sub_tree(n) _remove_bn(subtree1) assert subtree1.get_node("bn1") is None @@ -231,10 +256,10 @@ def test_insert_subtree_node(): net = ForNetWithSubTree() stree = SymbolTree.create(net) for node in stree.nodes(): - if node.get_name() == "layer2" and node.get_node_type() == NodeType.Tree: + if node.get_name() == "simple_net" and node.get_node_type() == NodeType.Tree: subtree = TreeNodeHelper.get_sub_tree(node) for n in subtree.nodes(): - if n.get_name() == "cell_list_1": + if n.get_name() == "block": subtree1 = TreeNodeHelper.get_sub_tree(n) orig_node_num = len(subtree1.get_handler()._nodes) _insert_node(subtree1) @@ -251,7 +276,7 @@ def test_resnet_replace_121(): stree: SymbolTree = SymbolTree.create(net) original_nodes_size = len(stree.get_handler()._nodes) for node in stree.nodes(): - if node.get_name() == "layer1" and node.get_node_type() == NodeType.Tree: + if node.get_name() == "simple_net" and node.get_node_type() == NodeType.Tree: subtree = TreeNodeHelper.get_sub_tree(node) for n in subtree.nodes(): if n.get_instance_type() == nn.Conv2d: @@ -274,7 +299,7 @@ def test_resnet_replace_12m(): stree: SymbolTree = SymbolTree.create(net) for node in stree.nodes(): - if node.get_name() == "layer1" and node.get_node_type() == NodeType.Tree: + if node.get_name() == "simple_net" and node.get_node_type() == NodeType.Tree: subtree = TreeNodeHelper.get_sub_tree(node) original_nodes_size = len(subtree.get_handler()._nodes) for n in subtree.nodes(): @@ -301,7 +326,7 @@ def test_node_fusion_in_subtree(): stree: SymbolTree = SymbolTree.create(net) original_nodes_size = len(stree.get_handler()._nodes) for node in stree.nodes(): - if node.get_name() == "layer1" and node.get_node_type() == NodeType.Tree: + if node.get_name() == "simple_net" and node.get_node_type() == NodeType.Tree: subtree = TreeNodeHelper.get_sub_tree(node) original_nodes_size = len(subtree.get_handler()._nodes) for n in subtree.nodes(): diff --git a/tests/ut/python/rewrite/test_net_simple.py b/tests/ut/python/rewrite/test_net_simple.py index c6c6c2cc514..5722c03ecd6 100644 --- a/tests/ut/python/rewrite/test_net_simple.py +++ b/tests/ut/python/rewrite/test_net_simple.py @@ -133,7 +133,6 @@ def test_simple_net(): net = SimpleNet(10) stree = SymbolTree.create(net) transform(stree) - print("------------------------------------ keys of global_vars: ", stree.get_handler().get_global_vars().keys()) net_opt = stree.get_network() data_in = Tensor(np.ones([1, 1, 32, 32]), mindspore.float32) _cell_graph_executor.compile(net_opt, data_in)