!45091 fix ckpt param name bug
Merge pull request !45091 from 于振华/fix_rewrite_paramname_1101
This commit is contained in:
commit
88c93e8e28
|
@ -172,6 +172,7 @@
|
|||
"mindspore/tests/ut/python/mindir/test_mindir_export.py" "no-else-return"
|
||||
"mindspore/tests/" "c-extension-no-member"
|
||||
"mindspore/tests/st/parameter/test_parameter_celllist.py" "protected-access"
|
||||
"mindspore/tests/ut/python/rewrite/test_cellcontainer.py" "protected-access"
|
||||
|
||||
#MindSpore Lite
|
||||
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental/HPC-generator/generator.py" "redefined-builtin"
|
||||
|
|
|
@ -279,6 +279,32 @@ class SequentialCell(Cell):
|
|||
input_data = cell(input_data)
|
||||
return input_data
|
||||
|
||||
def _insert(self, index, cell):
|
||||
"""
|
||||
Inserts a given Cell before a given index in the list.
|
||||
|
||||
Args:
|
||||
index(int): The Insert index in the CellList.
|
||||
cell(Cell): The Cell to be inserted.
|
||||
"""
|
||||
cls_name = self.__class__.__name__
|
||||
idx = _valid_index(len(self), index, cls_name)
|
||||
_valid_cell(cell, cls_name)
|
||||
length = len(self)
|
||||
prefix, key_index = _get_prefix_and_index(self._cells)
|
||||
while length > idx:
|
||||
if self._auto_prefix:
|
||||
tmp_cell = self._cells[str(length-1)]
|
||||
for _, param in tmp_cell.parameters_and_names():
|
||||
param.name = prefix + str(length) + "." + ".".join(param.name.split(".")[key_index+1:])
|
||||
self._cells[str(length)] = self._cells[str(length - 1)]
|
||||
length -= 1
|
||||
self._cells[str(idx)] = cell
|
||||
if self._auto_prefix:
|
||||
cell.update_parameters_name(prefix + str(idx) + ".")
|
||||
self.cell_list = list(self._cells.values())
|
||||
self._is_dynamic_name.insert(index, True)
|
||||
|
||||
|
||||
class CellList(_CellListBase, Cell):
|
||||
"""
|
||||
|
|
|
@ -43,3 +43,4 @@ class NodeType(Enum):
|
|||
Input = 7
|
||||
Output = 8
|
||||
Tree = 9
|
||||
CellContainer = 10
|
||||
|
|
|
@ -308,6 +308,16 @@ class PatternEngine:
|
|||
queue.extend(inputs_dict.get(cur_node.get_name()))
|
||||
return new_root
|
||||
|
||||
@staticmethod
|
||||
def _multi_replace_cellcontainer(stree, cellcontainer, node, matched_dict, new_nodes):
|
||||
"""Replace node in CellContainer."""
|
||||
to_erase_list = list(matched_dict.values())
|
||||
stree.replace(Node(node), new_nodes)
|
||||
for n in reversed(to_erase_list):
|
||||
if n.get_handler() is node:
|
||||
continue
|
||||
stree.erase_node(n)
|
||||
|
||||
def apply(self, stree: SymbolTree) -> bool:
|
||||
"""
|
||||
Apply current pattern to a `SymbolTree`.
|
||||
|
@ -359,6 +369,9 @@ class PatternEngine:
|
|||
visited.append(cur_node)
|
||||
queue.extend(cur_node.get_users())
|
||||
continue
|
||||
if cur_node.get_node_type() == NodeType.CellContainer:
|
||||
self._process_cellcontainer(stree, cur_node.get_handler())
|
||||
continue
|
||||
visited.append(cur_node)
|
||||
matched, matched_dict = self._match(self._pattern, cur_node)
|
||||
# not matched
|
||||
|
@ -460,3 +473,21 @@ class PatternEngine:
|
|||
logger.debug("Check match failed, pattern leaked")
|
||||
return False
|
||||
return True
|
||||
|
||||
def _process_cellcontainer(self, stree, cellcontainer):
|
||||
"""Process CellContainer node."""
|
||||
for node in cellcontainer.nodes():
|
||||
if node.get_node_type() == NodeType.Tree:
|
||||
subtree = node.symbol_tree
|
||||
self.apply(SymbolTree(subtree))
|
||||
continue
|
||||
else:
|
||||
matched, matched_dict = self._match(self._pattern, Node(node))
|
||||
if not matched:
|
||||
continue
|
||||
new_nodes = []
|
||||
if self._replacement is not None:
|
||||
new_nodes = self._replacement(self._pattern, self._is_chain, matched_dict)
|
||||
if not new_nodes: # if replacement is empty, do nothing
|
||||
continue
|
||||
PatternEngine._multi_replace_cellcontainer(stree, cellcontainer, node, matched_dict, new_nodes)
|
||||
|
|
|
@ -241,8 +241,10 @@ class AstModifier(ast.NodeTransformer):
|
|||
An instance of ast.Assign which has been appended to 'init_func'.
|
||||
"""
|
||||
return AstModifier.insert_assign_to_function(init_func, targets=targets,
|
||||
args=[ScopedValue.create_variable_value(field)],
|
||||
expr=ScopedValue(ValueType.NamingValue, "global_vars", "get"))
|
||||
expr=ScopedValue(ValueType.NamingValue, "", "setattr"),
|
||||
args=[ScopedValue(ValueType.NamingValue, "obj"),
|
||||
ScopedValue.create_variable_value(field)])
|
||||
|
||||
|
||||
@staticmethod
|
||||
def create_call_assign(targets: [ScopedValue], expr: ScopedValue, args: [ScopedValue],
|
||||
|
|
|
@ -24,8 +24,6 @@ _ms_functional_ns = CellNamespace('mindspore.ops.functional')
|
|||
|
||||
def is_subtree(cls_name):
|
||||
"""Determine whether 'cls_name' is a subtree."""
|
||||
if cls_name == "SequentialCell":
|
||||
return True
|
||||
if cls_name == "QuantizeWrapperCell":
|
||||
return False
|
||||
if cls_name in _ms_common_ns or cls_name in _ms_nn_ns or cls_name in _ms_ops_ns:
|
||||
|
|
|
@ -624,7 +624,7 @@ class Node:
|
|||
"""
|
||||
self._targets = targets
|
||||
if self._node_type in (NodeType.CallCell, NodeType.CallMethod, NodeType.CallPrimitive,
|
||||
NodeType.Tree, NodeType.CallFunction):
|
||||
NodeType.Tree, NodeType.CallFunction, NodeType.CellContainer):
|
||||
self._sync_assign_targets_to_ast()
|
||||
|
||||
def get_func(self) -> ScopedValue:
|
||||
|
@ -1135,7 +1135,7 @@ class Node:
|
|||
|
||||
def _sync_arg(self):
|
||||
"""Sync _normalized_args to corresponding ast node when updated."""
|
||||
if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree):
|
||||
if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree, NodeType.CellContainer):
|
||||
self._sync_call_cell_args_to_ast()
|
||||
elif self._node_type == NodeType.Output:
|
||||
self._sync_return_node_to_ast()
|
||||
|
@ -1198,3 +1198,85 @@ class TreeNode(Node):
|
|||
if ast_node is None:
|
||||
ast_node = AstModifier.create_call_assign(new_targets, func, non_custom_args, non_custom_kwargs)
|
||||
return cls(tree, ast_node, new_targets, func, args, kwargs, name, instance)
|
||||
|
||||
|
||||
class CellContainer(Node):
|
||||
""" Container for saving cell-objects node. """
|
||||
class _Visitor():
|
||||
""" A iterator of CellContainer nodes. """
|
||||
def __init__(self, cellcontainer):
|
||||
self._cellcontainer = cellcontainer
|
||||
|
||||
def __len__(self):
|
||||
""" Get the number of nodes. """
|
||||
return self._cellcontainer.node_count
|
||||
|
||||
def __iter__(self):
|
||||
"""Create an iterator over the CellContainer."""
|
||||
count = len(self._cellcontainer.node_list)
|
||||
i = 0
|
||||
while i < count:
|
||||
curr = self._cellcontainer.node_list[i]
|
||||
if curr.valid:
|
||||
yield curr
|
||||
i += 1
|
||||
|
||||
def __init__(self, ast_node: ast.AST, targets: [ScopedValue], func: ScopedValue,
|
||||
args: [ScopedValue], kwargs: {str: ScopedValue}, name: str, instance):
|
||||
"""Constructor of CellContainer.
|
||||
|
||||
Args:
|
||||
ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast.
|
||||
targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
||||
func ([ScopedValue, optional]): An instance of ScopedValue. See detail in docstring of Node class.
|
||||
args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
||||
kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class.
|
||||
name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree.
|
||||
Name of node also used as field name in network class.
|
||||
instance: Object in network corresponding to this node.
|
||||
"""
|
||||
if isinstance(func, str):
|
||||
func = ScopedValue.create_naming_value(func)
|
||||
super().__init__(NodeType.CellContainer, ast_node, targets, func, args, kwargs, name, instance)
|
||||
self._node_list = list()
|
||||
self._node_count = 0
|
||||
|
||||
@property
|
||||
def node_count(self):
|
||||
"""Number of nodes."""
|
||||
return self._node_count
|
||||
|
||||
@node_count.setter
|
||||
def node_count(self, count):
|
||||
"""Set number of nodes."""
|
||||
self._node_count = count
|
||||
|
||||
@property
|
||||
def node_list(self):
|
||||
""" Get node list. """
|
||||
return self._node_list
|
||||
|
||||
def append(self, node):
|
||||
""" Append new node to node list. """
|
||||
self._node_list.append(node)
|
||||
self.get_instance().append(node.get_instance())
|
||||
self.node_count += 1
|
||||
|
||||
def erase(self, node):
|
||||
"""Erase node form container."""
|
||||
index = self.node_list.index(node)
|
||||
setattr(node, "valid", False)
|
||||
self.node_count -= 1
|
||||
index = self.get_instance().cell_list.index(node.get_instance())
|
||||
del self.get_instance()[index]
|
||||
|
||||
def insert(self, index, node):
|
||||
"""Insert node into container"""
|
||||
self.node_list.insert(index, node)
|
||||
setattr(node, "valid", True)
|
||||
self.get_instance()._insert(index, node.get_instance())
|
||||
self.node_count += 1
|
||||
|
||||
def nodes(self):
|
||||
""" Return a iterator of node."""
|
||||
return self._Visitor(self)
|
||||
|
|
|
@ -19,13 +19,13 @@ import astunparse
|
|||
|
||||
from mindspore import log as logger
|
||||
from mindspore._extends.parse.namespace import CellNamespace
|
||||
from mindspore.nn import Cell
|
||||
from mindspore.nn import Cell, SequentialCell
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import Primitive
|
||||
from mindspore.rewrite.parser_register import ParserRegister
|
||||
from mindspore.rewrite.namespace import is_subtree, is_functional, get_functional
|
||||
from mindspore.rewrite.symbol_tree import SymbolTree
|
||||
from mindspore.rewrite.node import Node, TreeNode
|
||||
from mindspore.rewrite.node import Node, TreeNode, CellContainer
|
||||
from mindspore.rewrite.parser import Parser
|
||||
from mindspore.rewrite.parser_register import reg_parser
|
||||
from mindspore.rewrite.api.scoped_value import ScopedValue, ValueType
|
||||
|
@ -286,10 +286,10 @@ class AssignParser(Parser):
|
|||
if target.attr != func_name:
|
||||
continue
|
||||
changed = True
|
||||
global_vars_key = "_".join([func_name, "args"])
|
||||
stree.add_global_vars(global_vars_key, sub_tree.get_global_vars())
|
||||
args_call = AstModifier.create_call(ScopedValue.create_naming_value("get", "global_vars"),
|
||||
[ScopedValue.create_variable_value(global_vars_key)])
|
||||
setattr(stree.get_origin_network(), func_name, sub_tree.get_origin_network())
|
||||
args_call = AstModifier.create_call(ScopedValue(ValueType.NamingValue, "", "getattr"),
|
||||
[ScopedValue(ValueType.NamingValue, "", "obj"),
|
||||
ScopedValue(ValueType.StringValue, "", func_name)])
|
||||
body.value = ast.Call(func=ast.Name(class_name, ast.Store()), args=[args_call], keywords=[])
|
||||
break
|
||||
return changed
|
||||
|
@ -308,6 +308,37 @@ class AssignParser(Parser):
|
|||
call_args = [AssignParser._create_scopedvalue(arg) for arg in father_ast_node.value.args]
|
||||
return Node.create_call_buildin_op(op, father_ast_node, targets, func, call_args, {})
|
||||
|
||||
def _cell_container_process(self, ast_node, stree, targets, func, call_args, call_kwargs, op_name, container_obj):
|
||||
""" parse cell container object."""
|
||||
cell_container = CellContainer(ast_node, targets, func, call_args, call_kwargs, op_name, container_obj)
|
||||
for i, cell in enumerate(container_obj):
|
||||
is_sub_tree = is_subtree(type(cell).__name__)
|
||||
if is_sub_tree:
|
||||
stb = SymbolTreeBuilder(cell)
|
||||
new_stree = stb.build()
|
||||
replacer = AstReplacer(new_stree.get_class_ast())
|
||||
replacer.replace_all(new_stree.get_ori_cls_name(), new_stree.get_opt_cls_name())
|
||||
tree = TreeNode.create_tree_node(new_stree, ast_node, targets, func, call_args, call_kwargs,
|
||||
type(cell).__name__, cell)
|
||||
setattr(tree, "container", cell_container)
|
||||
setattr(tree, "valid", True)
|
||||
tree.set_belong_symbol_tree(stree)
|
||||
cell_container.node_list.append(tree)
|
||||
cell_container.node_count += 1
|
||||
if i > 0:
|
||||
tree.set_inputs([cell_container.node_list[i-1]])
|
||||
else:
|
||||
node = Node.create_call_buildin_op(cell, ast_node, targets, func, call_args, call_kwargs,
|
||||
type(cell).__name__)
|
||||
setattr(node, "container", cell_container)
|
||||
setattr(node, "valid", True)
|
||||
node.set_belong_symbol_tree(stree)
|
||||
cell_container.node_list.append(node)
|
||||
cell_container.node_count += 1
|
||||
if i > 0:
|
||||
node.set_inputs([cell_container.node_list[i-1]])
|
||||
return cell_container
|
||||
|
||||
def _convert_ast_call_to_node(self, ast_node: ast.Call, father_ast_node: ast.Assign, stree: SymbolTree) -> Node:
|
||||
"""
|
||||
Convert ast.Call to a symbol tree node.
|
||||
|
@ -343,6 +374,10 @@ class AssignParser(Parser):
|
|||
return node
|
||||
raise RuntimeError(error_str(f"operator instance undefined.",
|
||||
child_node=ast_node.func, father_node=ast_node))
|
||||
if isinstance(op, SequentialCell):
|
||||
node = self._cell_container_process(father_ast_node, stree, targets, func, call_args, call_kwargs,
|
||||
func_name, op)
|
||||
return node
|
||||
if isinstance(op, Primitive):
|
||||
return Node.create_call_buildin_op(op, father_ast_node, targets, func, call_args, call_kwargs, func_name)
|
||||
if isinstance(op, Cell):
|
||||
|
|
|
@ -21,8 +21,7 @@ from mindspore._extends.parse.namespace import CellNamespace
|
|||
from ..symbol_tree import SymbolTree
|
||||
from ..parser import Parser
|
||||
from ..parser_register import ParserRegister, reg_parser
|
||||
from ..api.scoped_value import ScopedValue
|
||||
from ..ast_helpers import AstReplacer, AstModifier
|
||||
from ..ast_helpers import AstReplacer
|
||||
from ..common import error_str
|
||||
|
||||
|
||||
|
@ -124,9 +123,6 @@ class ClassDefParser(Parser):
|
|||
super_index = ClassDefParser._find_super_expr_of_init_func(init_ast)
|
||||
ClassDefParser._modify_arguments_of_init_func(init_ast)
|
||||
self._replace_ori_field_of_init_func(stree, init_ast.body, super_index)
|
||||
# re-find super_index for init_func changed in _replace_ori_field_of_init_func
|
||||
super_index = ClassDefParser._find_super_expr_of_init_func(init_ast)
|
||||
ClassDefParser._insert_handler_to_init_func(init_ast, super_index)
|
||||
|
||||
@staticmethod
|
||||
def _find_super_expr_of_init_func(ast_init_fn: ast.FunctionDef) -> int:
|
||||
|
@ -158,7 +154,7 @@ class ClassDefParser(Parser):
|
|||
def _modify_arguments_of_init_func(ast_init_fn: ast.FunctionDef):
|
||||
"""Replace init function input parameters to self and global_vars."""
|
||||
arg_self = ast.arg(arg="self", annotation="")
|
||||
arg_global_vars = ast.arg(arg="global_vars", annotation="")
|
||||
arg_global_vars = ast.arg(arg="obj", annotation="")
|
||||
ast_init_fn.args = ast.arguments(args=[arg_self, arg_global_vars], posonlyargs=[], kwonlyargs=[],
|
||||
kw_defaults=[], defaults=[], vararg=None, kwarg=None)
|
||||
ast.fix_missing_locations(ast_init_fn)
|
||||
|
@ -235,22 +231,12 @@ class ClassDefParser(Parser):
|
|||
continue
|
||||
field_name = target.attr
|
||||
body.value = ast.Call(ast.Name('getattr', ast.Load()),
|
||||
[ast.Attribute(ast.Name('self', ast.Load()), '_handler', ast.Load()),
|
||||
[ast.Name('obj', ast.Load()),
|
||||
ast.Constant(value=field_name, kind=None)], [])
|
||||
for counter, index in enumerate(body_index_to_be_deleted):
|
||||
bodies.pop(index - counter)
|
||||
ClassDefParser._remove_empty_ast_in_init_func(bodies)
|
||||
|
||||
@staticmethod
|
||||
def _insert_handler_to_init_func(ast_init_fn: ast.FunctionDef, super_index):
|
||||
"""Insert 'self._handler = global_vars.get('handler')' to init ast.FunctionDef.body"""
|
||||
if super_index == -1:
|
||||
super_index = 0
|
||||
AstModifier.insert_assign_to_function(ast_init_fn, [ScopedValue.create_naming_value("_handler", "self")],
|
||||
ScopedValue.create_naming_value("get", "global_vars"),
|
||||
[ScopedValue.create_variable_value("handler")], None,
|
||||
ast_init_fn.body[super_index], False)
|
||||
|
||||
def process(self, stree: SymbolTree, node: ast.ClassDef):
|
||||
"""
|
||||
Parse init and construct in ast.ClassDef.
|
||||
|
|
|
@ -34,12 +34,13 @@ class ForParser(Parser):
|
|||
def modify_init_ast(stree, i, obj, iter_var_name):
|
||||
"""Modify the ast node in init function."""
|
||||
target = f"{iter_var_name.strip()}_{str(i)}"
|
||||
stree.add_global_vars(target, obj)
|
||||
setattr(stree.get_origin_network(), target, obj)
|
||||
stree.get_origin_network().insert_child_to_cell(target, obj)
|
||||
AstModifier.insert_assign_to_function(stree.get_init_func_ast(),
|
||||
targets=[ScopedValue(ValueType.NamingValue, "self", target)],
|
||||
expr=ScopedValue(ValueType.NamingValue, "global_vars", "get"),
|
||||
args=[ScopedValue(ValueType.StringValue, "", target)])
|
||||
expr=ScopedValue(ValueType.NamingValue, "", "getattr"),
|
||||
args=[ScopedValue(ValueType.NamingValue, "", "obj"),
|
||||
ScopedValue(ValueType.StringValue, "", target)])
|
||||
|
||||
@staticmethod
|
||||
def modify_construct_ast(stree, ast_node, old_name, new_name):
|
||||
|
|
|
@ -44,7 +44,7 @@ class FunctionDefParser(Parser):
|
|||
else:
|
||||
parser.process(stree, body)
|
||||
|
||||
for body in node.body:
|
||||
for body in node.body[::-1]:
|
||||
if isinstance(body, (ast.For, ast.If)):
|
||||
node.body.remove(body)
|
||||
if hasattr(node, "decorator_list"):
|
||||
|
|
|
@ -160,7 +160,6 @@ class SymbolTree(Observer, Observable):
|
|||
self._topo_mgr = TopoManager()
|
||||
self._topo_mgr.reg_observer(self)
|
||||
|
||||
self._global_vars: {str, object} = {origin_network_key: origin_network}
|
||||
self._nodes: {str, Node} = {}
|
||||
# parameters of forward method
|
||||
self._inputs: [Node] = []
|
||||
|
@ -484,17 +483,6 @@ class SymbolTree(Observer, Observable):
|
|||
"""
|
||||
return self._origin_network
|
||||
|
||||
def get_global_vars(self):
|
||||
"""Get global variables."""
|
||||
return self._global_vars
|
||||
|
||||
def add_global_vars(self, key: str, value):
|
||||
"""Add global variables."""
|
||||
if self._global_vars.get(key) is not None:
|
||||
logger.info(f"The key '{key}' is duplicated")
|
||||
return
|
||||
self._global_vars[key] = value
|
||||
|
||||
def get_nodes_dict(self):
|
||||
"""Get dict of nodes"""
|
||||
return self._nodes
|
||||
|
@ -614,7 +602,6 @@ class SymbolTree(Observer, Observable):
|
|||
RuntimeError: If 'node_or_name' is not belong to this SymbolTree or any sub-SymbolTree of current
|
||||
SymbolTree.
|
||||
"""
|
||||
|
||||
node = self._get_real_node(node_or_name)
|
||||
if node is None:
|
||||
raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
|
||||
|
@ -653,7 +640,12 @@ class SymbolTree(Observer, Observable):
|
|||
RuntimeError: If 'position' is not in current SymbolTree.
|
||||
RuntimeError: If corresponding ast node is not an ast.Assign when 'insert_to_ast' is True.
|
||||
"""
|
||||
|
||||
if position is not None and hasattr(position.node, "container"):
|
||||
cellcontainer = getattr(position.node, "container")
|
||||
index = cellcontainer.node_list.index(position.node)
|
||||
index = index if position.before_node else index + 1
|
||||
cellcontainer.insert(index, node)
|
||||
return node
|
||||
# if position in current SymbolTree
|
||||
if position is not None and position.symbol_tree is not self:
|
||||
raise RuntimeError("Position is not in current SymbolTree:", position)
|
||||
|
@ -683,10 +675,10 @@ class SymbolTree(Observer, Observable):
|
|||
if not isinstance(node_ast, ast.Assign):
|
||||
raise RuntimeError("Only support insert cell op now")
|
||||
if isinstance(node, TreeNode):
|
||||
global_vars_key = node.get_name() + "_args"
|
||||
self.add_global_vars(global_vars_key, node.symbol_tree.get_global_vars())
|
||||
args_call = AstModifier.create_call(ScopedValue.create_naming_value("get", "global_vars"),
|
||||
[ScopedValue.create_variable_value(global_vars_key)])
|
||||
setattr(self._origin_network, node.get_name(), node.get_instance())
|
||||
args_call = AstModifier.create_call(ScopedValue(ValueType.NamingValue, "", "getattr"),
|
||||
[ScopedValue(ValueType.NamingValue, "", "obj"),
|
||||
ScopedValue(ValueType.StringValue, "", node.get_name())])
|
||||
value = ast.Call(func=ast.Name(node.symbol_tree.get_opt_cls_name(), ast.Store(), lineno=0,
|
||||
col_offset=0), args=[args_call], keywords=[], lineno=0, col_offset=0)
|
||||
|
||||
|
@ -703,12 +695,13 @@ class SymbolTree(Observer, Observable):
|
|||
else:
|
||||
AstModifier.insert_assign_to_function(self._init_func_ast,
|
||||
targets=[ScopedValue(ValueType.NamingValue, "self", node_name)],
|
||||
expr=ScopedValue(ValueType.NamingValue, "global_vars", "get"),
|
||||
args=[ScopedValue(ValueType.StringValue, "", node_name)])
|
||||
expr=ScopedValue(ValueType.NamingValue, "", "getattr"),
|
||||
args=[ScopedValue(ValueType.NamingValue, "", "obj"),
|
||||
ScopedValue(ValueType.StringValue, "", node_name)])
|
||||
AstModifier.insert_assign_ast_to_function(self._root_ast, node_ast,
|
||||
None if position is None else position.node.get_ast(),
|
||||
position.before_node)
|
||||
self._global_vars[node_name] = node.get_instance()
|
||||
setattr(self._origin_network, node_name, node.get_instance())
|
||||
return node
|
||||
|
||||
def append_node(self, node: Node, append_to_ast: bool = True) -> Node:
|
||||
|
@ -851,6 +844,10 @@ class SymbolTree(Observer, Observable):
|
|||
node = self._get_real_node(node_or_name)
|
||||
if node is None:
|
||||
raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
|
||||
if hasattr(node, "container"):
|
||||
cellcontainer = getattr(node, "container")
|
||||
cellcontainer.erase(node)
|
||||
return node
|
||||
ret = AstModifier.erase_ast_from_function(self._root_ast, node.get_ast())
|
||||
if not ret:
|
||||
raise RuntimeError("node not in function ast tree.")
|
||||
|
@ -884,6 +881,9 @@ class SymbolTree(Observer, Observable):
|
|||
RuntimeError: If 'old_node' is not belong to current SymbolTree.
|
||||
"""
|
||||
|
||||
if hasattr(old_node, "container"):
|
||||
self._replace_container_node(old_node, new_nodes)
|
||||
return new_nodes[0]
|
||||
real_old_node = self._get_real_node(old_node)
|
||||
if real_old_node is None:
|
||||
raise RuntimeError("Old node is not belong to current SymbolTree:", old_node)
|
||||
|
@ -1026,7 +1026,7 @@ class SymbolTree(Observer, Observable):
|
|||
A network object.
|
||||
"""
|
||||
cls = self._get_cls_through_file()
|
||||
return cls(self._global_vars)
|
||||
return cls(self._origin_network)
|
||||
|
||||
def set_saved_file_name(self, file_name: str):
|
||||
"""Sets the filename used to save the network."""
|
||||
|
@ -1070,6 +1070,14 @@ class SymbolTree(Observer, Observable):
|
|||
else:
|
||||
body.names.remove(alias)
|
||||
|
||||
def _replace_container_node(self, old_node, new_nodes):
|
||||
cellcontainer = getattr(old_node, "container")
|
||||
index = cellcontainer.node_list.index(old_node)
|
||||
for n in reversed(new_nodes):
|
||||
cellcontainer.insert(index, n)
|
||||
index = cellcontainer.node_list.index(old_node)
|
||||
cellcontainer.erase(old_node)
|
||||
|
||||
def _filter_out_to_delete_field(self, to_delete_field):
|
||||
"""filter out used field from `to_delete_field`"""
|
||||
# filter _handler field
|
||||
|
@ -1077,7 +1085,8 @@ class SymbolTree(Observer, Observable):
|
|||
to_delete_field.pop("_handler")
|
||||
# filter field used in node of construct
|
||||
for node in self._nodes.values():
|
||||
if node.get_node_type() in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree):
|
||||
if node.get_node_type() in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree,
|
||||
NodeType.CellContainer):
|
||||
func: ScopedValue = node.get_func()
|
||||
if func.scope == "self" and to_delete_field.get(func.value):
|
||||
to_delete_field.pop(func.value)
|
||||
|
@ -1144,12 +1153,9 @@ class SymbolTree(Observer, Observable):
|
|||
self._module_ast.body.remove(body)
|
||||
|
||||
def _get_real_node(self, node_or_name: Union[Node, str]) -> Optional[Node]:
|
||||
if isinstance(node_or_name, Node):
|
||||
result = self.get_node(node_or_name.get_name())
|
||||
return result if result is node_or_name else None
|
||||
if isinstance(node_or_name, str):
|
||||
return self.get_node(node_or_name)
|
||||
return None
|
||||
return node_or_name
|
||||
|
||||
def _insert_tree(self, position: Position, root: Node, insert_to_ast: bool = True) -> Node:
|
||||
"""
|
||||
|
@ -1298,7 +1304,7 @@ class SymbolTree(Observer, Observable):
|
|||
raise TypeError("value should be ScopedValue, got: ", type(value))
|
||||
if value.type == ValueType.CustomObjValue:
|
||||
field = self._node_name_namer.get_name(f"var_{type(value.value).__name__}")
|
||||
self._global_vars[field] = value.value
|
||||
setattr(self._origin_network, field, value.value)
|
||||
init_targets = [ScopedValue.create_naming_value(field, "self")]
|
||||
AstModifier.append_global_vars_expr_to_init(self._init_func_ast, init_targets, field)
|
||||
result[arg] = init_targets[0]
|
||||
|
@ -1316,7 +1322,8 @@ class SymbolTree(Observer, Observable):
|
|||
Returns:
|
||||
A class handle.
|
||||
"""
|
||||
file_name = "new_network_{0}.py".format(int(time.time() * 10000))
|
||||
self._update_container()
|
||||
file_name = "new_network_{0}.py".format(int(time.time() * 10000000))
|
||||
with os.fdopen(os.open(file_name, os.O_WRONLY | os.O_CREAT, stat.S_IRWXU), 'wb') as f:
|
||||
source = self.get_code()
|
||||
f.write(source.encode('utf-8'))
|
||||
|
@ -1333,3 +1340,18 @@ class SymbolTree(Observer, Observable):
|
|||
def _on_change(self, event: Event):
|
||||
self._modified = True
|
||||
self.changed(event)
|
||||
|
||||
def _update_container(self):
|
||||
"""Update instance of node in container."""
|
||||
for node in self.nodes():
|
||||
index = 0
|
||||
if node.get_node_type() == NodeType.CellContainer:
|
||||
for n in node.node_list:
|
||||
if not n.valid:
|
||||
continue
|
||||
if n.get_node_type() == NodeType.Tree:
|
||||
obj = n.symbol_tree.get_network()
|
||||
node.get_instance()[index] = obj
|
||||
else:
|
||||
node.get_instance()[index] = n.get_instance()
|
||||
index += 1
|
||||
|
|
|
@ -100,7 +100,11 @@ def _insert_cast_operator(stree):
|
|||
if node.get_targets() is None:
|
||||
continue
|
||||
in_white_list = False
|
||||
if node.get_node_type() != ms.rewrite.NodeType.Tree:
|
||||
if node.get_node_type() == ms.rewrite.NodeType.CellContainer:
|
||||
for n in node.get_handler().node_list:
|
||||
if n.get_node_type() == ms.rewrite.NodeType.Tree:
|
||||
_insert_cast_operator(ms.rewrite.TreeNodeHelper.get_sub_tree(ms.rewrite.Node(n)))
|
||||
elif node.get_node_type() != ms.rewrite.NodeType.Tree:
|
||||
# insert cast before the primitive operators in white_list
|
||||
if node.get_instance_type() in AMP_WHITE_LIST_OPS:
|
||||
in_white_list = True
|
||||
|
@ -165,7 +169,11 @@ def _remove_duplicated_cast(stree):
|
|||
for node in stree.nodes():
|
||||
if node.get_targets() is None:
|
||||
continue
|
||||
if node.get_node_type() != ms.rewrite.NodeType.Tree:
|
||||
if node.get_node_type() == ms.rewrite.NodeType.CellContainer:
|
||||
for n in node.get_handler().node_list:
|
||||
if n.get_node_type() == ms.rewrite.NodeType.Tree:
|
||||
_remove_duplicated_cast(ms.rewrite.TreeNodeHelper.get_sub_tree(ms.rewrite.Node(n)))
|
||||
elif node.get_node_type() != ms.rewrite.NodeType.Tree:
|
||||
if node.get_instance_type() == P.Cast and _removed_cast_pair(node):
|
||||
# remove the following cast node first
|
||||
len_users = len(node.get_users())
|
||||
|
|
|
@ -0,0 +1,402 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test cell container."""
|
||||
|
||||
from mindspore import nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
from mindspore.rewrite import SymbolTree, NodeType, TreeNodeHelper, Node, ScopedValue, PatternEngine, Replacement, \
|
||||
PatternNode
|
||||
|
||||
|
||||
def _conv3x3(in_channel, out_channel, stride=1):
|
||||
return nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride,
|
||||
padding=0, pad_mode='same', weight_init="ones")
|
||||
|
||||
|
||||
def _conv1x1(in_channel, out_channel, stride=1):
|
||||
return nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride,
|
||||
padding=0, pad_mode='same', weight_init="ones")
|
||||
|
||||
|
||||
def _bn(channel):
|
||||
return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
|
||||
gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
|
||||
|
||||
|
||||
class ResidualBlock(nn.Cell):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self,
|
||||
in_channel,
|
||||
out_channel,
|
||||
stride=1):
|
||||
super(ResidualBlock, self).__init__()
|
||||
self.stride = stride
|
||||
channel = out_channel // self.expansion
|
||||
self.conv1 = _conv1x1(in_channel, channel, stride=1)
|
||||
self.bn1 = _bn(channel)
|
||||
self.conv2 = _conv3x3(channel, channel, stride=stride)
|
||||
self.bn2 = _bn(channel)
|
||||
self.conv3 = _conv1x1(channel, out_channel, stride=1)
|
||||
self.bn3 = _bn(out_channel)
|
||||
self.relu = nn.ReLU()
|
||||
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride), _bn(out_channel)])
|
||||
|
||||
def construct(self, x):
|
||||
identity = x
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
identity = self.down_sample_layer(identity)
|
||||
out = out + identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNetSimple(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ResNetSimple, self).__init__(auto_prefix=True)
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad', weight_init="ones")
|
||||
self.bn1 = _bn(16)
|
||||
self.relu = P.ReLU()
|
||||
self.layer1 = self._make_layer(ResidualBlock, 3, in_channel=63, out_channel=256, stride=1)
|
||||
self.layer1.append(self.conv1)
|
||||
self.layer1.append(self.bn1)
|
||||
self.reshape = P.Reshape()
|
||||
self.out_channels = 10
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.layer1(x)
|
||||
return x
|
||||
|
||||
def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
|
||||
layers = []
|
||||
resnet_block = block(in_channel, out_channel, stride=stride)
|
||||
layers.append(resnet_block)
|
||||
for _ in range(1, layer_num):
|
||||
resnet_block = ResidualBlock(out_channel, out_channel, stride=1)
|
||||
layers.append(resnet_block)
|
||||
return nn.SequentialCell(layers)
|
||||
|
||||
|
||||
def test_cellcontainer_parse():
|
||||
"""
|
||||
Feature: parse CellContainer Node.
|
||||
Description: parse a network with SequentialCell object.
|
||||
Expectation: Rewrite can parse a network with SquentialCell object successfully.
|
||||
"""
|
||||
net = ResNetSimple()
|
||||
stree = SymbolTree.create(net)
|
||||
for node in stree.nodes():
|
||||
if node.get_node_type() == NodeType.CellContainer:
|
||||
assert len(node.get_handler().node_list) == 5
|
||||
for i, n in enumerate(node.get_handler().node_list):
|
||||
if i < 3:
|
||||
assert n.get_instance_type() is ResidualBlock
|
||||
if i == 3:
|
||||
assert n.get_instance_type() is nn.Conv2d
|
||||
if i == 4:
|
||||
assert n.get_instance_type() is nn.BatchNorm2d
|
||||
|
||||
|
||||
def test_cellcontainer_insert():
|
||||
"""
|
||||
Feature: modify CellContainer Node.
|
||||
Description: using node in container to set insert location.
|
||||
Expectation: raise ValueError.
|
||||
"""
|
||||
net = ResNetSimple()
|
||||
stree = SymbolTree.create(net)
|
||||
for node in stree.nodes():
|
||||
if node.get_node_type() == NodeType.CellContainer:
|
||||
assert len(node.get_handler().nodes()) == 5
|
||||
for n in node.get_handler().nodes():
|
||||
if n.get_instance_type() is nn.Conv2d:
|
||||
position = stree.before(Node(n))
|
||||
new_conv = nn.Conv2d(16, 16, 3)
|
||||
new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv',
|
||||
args=[ScopedValue.create_naming_value('self_max_po')])
|
||||
stree.insert(position, new_conv_node)
|
||||
break
|
||||
assert len(node.get_handler().nodes()) == 6
|
||||
assert node.get_handler().node_list[3].get_name() == "new_conv"
|
||||
|
||||
|
||||
def test_cellcontainer_insert_ok():
|
||||
"""
|
||||
Feature: modify CellContainer Node.
|
||||
Description: Inserts a node within a tree node in CellContainer Node.
|
||||
Expectation: Insertion succeeded.
|
||||
"""
|
||||
def _insert_conv(stree: SymbolTree):
|
||||
for node in stree.nodes():
|
||||
if node.get_instance_type() == nn.BatchNorm2d:
|
||||
position = stree.after(node)
|
||||
new_conv = nn.Conv2d(16, 16, 3)
|
||||
new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv',
|
||||
args=[ScopedValue.create_naming_value('self_max_po')])
|
||||
stree.insert(position, new_conv_node)
|
||||
break
|
||||
net = ResNetSimple()
|
||||
stree = SymbolTree.create(net)
|
||||
for node in stree.nodes():
|
||||
if node.get_node_type() == NodeType.CellContainer:
|
||||
for n in node.get_handler().node_list:
|
||||
if n.get_node_type() == NodeType.Tree:
|
||||
_insert_conv(TreeNodeHelper.get_sub_tree(Node(n)))
|
||||
break
|
||||
new_net = stree.get_network()
|
||||
cell_container = getattr(new_net, "layer1")
|
||||
assert hasattr(cell_container._cells["0"], "new_conv")
|
||||
|
||||
|
||||
def test_cellcontainer_insert_to_subtree():
|
||||
"""
|
||||
Feature: modify CellContainer Node.
|
||||
Description: Inserts a node within a tree node in CellContainer Node.
|
||||
Expectation: Insertion succeeded.
|
||||
"""
|
||||
def _insert_conv(stree: SymbolTree):
|
||||
for node in stree.nodes():
|
||||
if node.get_instance_type() == nn.BatchNorm2d:
|
||||
position = stree.after(node)
|
||||
new_conv = nn.Conv2d(16, 16, 3)
|
||||
new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv',
|
||||
args=[ScopedValue.create_naming_value('self_max_po')])
|
||||
stree.insert(position, new_conv_node)
|
||||
break
|
||||
net = ResNetSimple()
|
||||
stree = SymbolTree.create(net)
|
||||
for node in stree.nodes():
|
||||
if node.get_node_type() == NodeType.CellContainer:
|
||||
for n in node.get_handler().node_list:
|
||||
if n.get_node_type() == NodeType.Tree:
|
||||
_insert_conv(TreeNodeHelper.get_sub_tree(Node(n)))
|
||||
break
|
||||
new_net = stree.get_network()
|
||||
cell_container = getattr(new_net, "layer1")
|
||||
assert hasattr(cell_container._cells["0"], "new_conv")
|
||||
|
||||
|
||||
def test_cellcontainer_del():
|
||||
"""
|
||||
Feature: modify CellContainer Node.
|
||||
Description: delete the CellContainer Node.
|
||||
Expectation: success.
|
||||
"""
|
||||
net = ResNetSimple()
|
||||
stree = SymbolTree.create(net)
|
||||
original_nodes_size = len(stree.get_handler()._nodes)
|
||||
for node in stree.nodes():
|
||||
if node.get_node_type() == NodeType.CellContainer and node.get_name() == "layer1":
|
||||
users = node.get_users()
|
||||
for user in users:
|
||||
user.set_arg(0, "x")
|
||||
stree.erase_node(node)
|
||||
assert len(stree.get_handler()._nodes) == original_nodes_size - 1
|
||||
|
||||
|
||||
def test_cellcontainer_del_node():
|
||||
"""
|
||||
Feature: modify CellContainer Node.
|
||||
Description: delete the CellContainer Node.
|
||||
Expectation: success.
|
||||
"""
|
||||
net = ResNetSimple()
|
||||
stree = SymbolTree.create(net)
|
||||
for node in stree.nodes():
|
||||
if node.get_node_type() == NodeType.CellContainer and node.get_name() == "layer1":
|
||||
assert len(node.get_handler().nodes()) == 5
|
||||
for n in node.get_handler().nodes():
|
||||
users = node.get_users()
|
||||
inputs = node.get_inputs()
|
||||
for user in users:
|
||||
user.set_arg_by_node(0, inputs[0])
|
||||
stree.erase_node(Node(n))
|
||||
break
|
||||
assert len(node.get_handler().nodes()) == 4
|
||||
|
||||
|
||||
def test_cellcontainer_del_node_in_subtree():
|
||||
"""
|
||||
Feature: modify CellContainer Node.
|
||||
Description: delete a node within a tree node in CellContainer Node.
|
||||
Expectation: success.
|
||||
"""
|
||||
def _del_node(sub_tree):
|
||||
for _node in sub_tree.nodes():
|
||||
if _node.get_name() == "conv2":
|
||||
users = Node(_node).get_users()
|
||||
for user in users:
|
||||
user.set_arg(0, "out")
|
||||
sub_tree.erase_node(_node)
|
||||
net = ResNetSimple()
|
||||
stree = SymbolTree.create(net)
|
||||
for node in stree.nodes():
|
||||
if node.get_node_type() == NodeType.CellContainer:
|
||||
for i, n in enumerate(node.get_handler().node_list):
|
||||
if n.get_node_type() == NodeType.Tree and i == 1:
|
||||
sub_tree = n.symbol_tree
|
||||
original_nodes_size = len(sub_tree._nodes)
|
||||
_del_node(sub_tree)
|
||||
assert len(sub_tree._nodes) == original_nodes_size - 1
|
||||
|
||||
new_net = stree.get_network()
|
||||
cell_container = getattr(new_net, "layer1")
|
||||
assert not hasattr(cell_container._cells["1"], "conv2")
|
||||
|
||||
|
||||
def test_cellcontainer_replace():
|
||||
"""
|
||||
Feature: modify CellContainer Node.
|
||||
Description: replace CellContainer Node with another Node.
|
||||
Expectation: success.
|
||||
"""
|
||||
def _replace_bn(stree: SymbolTree):
|
||||
for node in stree.nodes():
|
||||
if node.get_node_type() == NodeType.CellContainer:
|
||||
new_conv = nn.Conv2d(16, 16, 3)
|
||||
new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv',
|
||||
args=[ScopedValue.create_naming_value('x')])
|
||||
stree.replace(node, [new_conv_node])
|
||||
break
|
||||
net = ResNetSimple()
|
||||
stree = SymbolTree.create(net)
|
||||
_replace_bn(stree)
|
||||
new_net = stree.get_network()
|
||||
assert not hasattr(new_net, "layer1")
|
||||
assert hasattr(new_net, "new_conv")
|
||||
|
||||
|
||||
def test_cellcontainer_replace_node():
|
||||
"""
|
||||
Feature: modify CellContainer Node.
|
||||
Description: replace the CellContainer Node.
|
||||
Expectation: success.
|
||||
"""
|
||||
net = ResNetSimple()
|
||||
stree = SymbolTree.create(net)
|
||||
for node in stree.nodes():
|
||||
if node.get_node_type() == NodeType.CellContainer and node.get_name() == "layer1":
|
||||
for n in node.get_handler().nodes():
|
||||
new_conv = nn.Conv2d(16, 16, 3)
|
||||
new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv',
|
||||
args=[ScopedValue.create_naming_value('x')])
|
||||
stree.replace(Node(n), [new_conv_node])
|
||||
break
|
||||
assert node.get_handler().node_list[0].get_name() == "new_conv"
|
||||
assert isinstance(node.get_handler().get_instance()._cells["0"], nn.Conv2d)
|
||||
break
|
||||
|
||||
|
||||
def test_cellcontainer_replace_in_subtree():
|
||||
"""
|
||||
Feature: modify CellContainer Node.
|
||||
Description: replace a node within a tree node in CellContainer Node.
|
||||
Expectation: success.
|
||||
"""
|
||||
def _replace_bn(stree: SymbolTree):
|
||||
for node in stree.nodes():
|
||||
if node.get_name() == "bn1":
|
||||
new_conv = nn.Conv2d(16, 16, 3)
|
||||
new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv',
|
||||
args=[ScopedValue.create_naming_value('self_max_po')])
|
||||
stree.replace(node, [new_conv_node])
|
||||
break
|
||||
net = ResNetSimple()
|
||||
stree = SymbolTree.create(net)
|
||||
for node in stree.nodes():
|
||||
if node.get_node_type() == NodeType.CellContainer:
|
||||
for n in node.get_handler().node_list:
|
||||
if n.get_node_type() == NodeType.Tree:
|
||||
_replace_bn(TreeNodeHelper.get_sub_tree(Node(n)))
|
||||
break
|
||||
new_net = stree.get_network()
|
||||
cell_container = getattr(new_net, "layer1")
|
||||
assert not hasattr(cell_container._cells["0"], "bn1")
|
||||
assert hasattr(cell_container._cells["0"], "new_conv")
|
||||
|
||||
|
||||
def test_cellcontainer_pattern():
|
||||
"""
|
||||
Feature: modify CellContainer Node.
|
||||
Description: apply pattern matching and replacement on the network containing SequentialCell object.
|
||||
Expectation: success.
|
||||
"""
|
||||
class ConvBnReplacement(Replacement):
|
||||
def build(self, pattern: PatternNode, is_chain_pattern: bool, matched):
|
||||
assert is_chain_pattern
|
||||
assert pattern.type() == nn.BatchNorm2d
|
||||
bn_node: Node = matched.get(pattern.name())
|
||||
assert bn_node is not None
|
||||
assert len(pattern.get_inputs()) == 1
|
||||
add_pattern = pattern.get_inputs()[0]
|
||||
assert add_pattern.type() == nn.Conv2d
|
||||
add_node: Node = matched.get(add_pattern.name())
|
||||
assert add_node is not None
|
||||
assert not add_pattern.get_inputs()
|
||||
|
||||
new_maxpool1 = nn.MaxPool2d()
|
||||
new_maxpool1_node = Node.create_call_cell(new_maxpool1, ['new_maxpool1'], add_node.get_args())
|
||||
new_relu1 = nn.ReLU()
|
||||
new_relu1_node = Node.create_call_cell(new_relu1, ['new_relu_1'],
|
||||
[ScopedValue.create_naming_value('new_maxpool1')])
|
||||
new_relu2 = nn.ReLU()
|
||||
new_relu2_node = Node.create_call_cell(new_relu2, ['new_relu_2'],
|
||||
[ScopedValue.create_naming_value('new_maxpool1')])
|
||||
new_maxpool2 = nn.BiDense(1, 1, 2)
|
||||
new_maxpool2_node = Node.create_call_cell(new_maxpool2, ['new_maxpool2'],
|
||||
[ScopedValue.create_naming_value('new_relu_1'),
|
||||
ScopedValue.create_naming_value('new_relu_2')])
|
||||
return [new_maxpool1_node, new_relu1_node, new_relu2_node, new_maxpool2_node]
|
||||
|
||||
|
||||
class ConvReluPattern(PatternEngine):
|
||||
def __init__(self):
|
||||
super().__init__([nn.Conv2d, nn.BatchNorm2d], ConvBnReplacement())
|
||||
|
||||
net = ResNetSimple()
|
||||
stree = SymbolTree.create(net)
|
||||
_pattern = ConvReluPattern()
|
||||
_pattern.apply(stree)
|
||||
new_net = stree.get_network()
|
||||
cell_container = getattr(new_net, "layer1")
|
||||
assert not hasattr(cell_container, "conv1")
|
||||
assert not hasattr(cell_container, "bn1")
|
||||
assert not hasattr(cell_container._cells["0"], "conv1")
|
||||
assert not hasattr(cell_container._cells["1"], "conv1")
|
||||
assert not hasattr(cell_container._cells["2"], "conv1")
|
||||
assert hasattr(cell_container._cells["0"], "new_relu")
|
||||
assert hasattr(cell_container._cells["0"], "new_maxpool1")
|
||||
assert isinstance(getattr(getattr(cell_container._cells["0"], "down_sample_layer"), "0"), nn.MaxPool2d)
|
||||
assert hasattr(cell_container._cells["1"], "new_relu")
|
||||
assert hasattr(cell_container._cells["1"], "new_maxpool1")
|
||||
assert isinstance(getattr(getattr(cell_container._cells["1"], "down_sample_layer"), "0"), nn.MaxPool2d)
|
||||
assert hasattr(cell_container._cells["2"], "new_relu")
|
||||
assert hasattr(cell_container._cells["2"], "new_maxpool1")
|
||||
assert isinstance(getattr(getattr(cell_container._cells["2"], "down_sample_layer"), "0"), nn.MaxPool2d)
|
||||
assert isinstance(getattr(cell_container, "3"), nn.MaxPool2d)
|
||||
assert isinstance(getattr(cell_container, "4"), nn.ReLU)
|
||||
assert isinstance(getattr(cell_container, "6"), nn.BiDense)
|
|
@ -2,6 +2,7 @@
|
|||
from collections import OrderedDict
|
||||
|
||||
from mindspore import nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.rewrite import SymbolTree, PatternEngine, Replacement, PatternNode, Node, ScopedValue
|
||||
from mindspore.rewrite.api.tree_node_helper import TreeNodeHelper
|
||||
from mindspore.rewrite.api.node_type import NodeType
|
||||
|
@ -108,6 +109,28 @@ class CellBlock(nn.Cell):
|
|||
return out
|
||||
|
||||
|
||||
class SimpleNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mul = P.Mul()
|
||||
self.dense = nn.Dense(in_channels=32, out_channels=32, weight_init="ones")
|
||||
self.mean = P.ReduceMean(keep_dims=False)
|
||||
self.split = P.Split(axis=1, output_num=3)
|
||||
self.conv1 = nn.Conv2d(3, 64, 3, stride=2)
|
||||
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.block = CellBlock(3, 6)
|
||||
|
||||
def construct(self, x):
|
||||
y, _, _ = self.split(x)
|
||||
y = self.mean(y, (2, 3))
|
||||
x = self.mul(x, 1)
|
||||
x = self.block(x)
|
||||
x = self.conv1(x)
|
||||
x = self.max_pool2d(x)
|
||||
x = self.dense(x)
|
||||
return x, y
|
||||
|
||||
|
||||
class ForNetWithSubTree(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ForNetWithSubTree, self).__init__()
|
||||
|
@ -125,12 +148,14 @@ class ForNetWithSubTree(nn.Cell):
|
|||
resnet_block3 = CellBlock(16, 32)
|
||||
layers = [resnet_block1, resnet_block2, resnet_block3]
|
||||
self.layer2 = nn.SequentialCell(layers)
|
||||
self.simple_net = SimpleNet()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.layer1(x)
|
||||
x = self.relu(x)
|
||||
x = self.layer2(x)
|
||||
x = self.simple_net(x)
|
||||
return x
|
||||
|
||||
|
||||
|
@ -144,7 +169,7 @@ def test_erase_subtree_node():
|
|||
stree = SymbolTree.create(net)
|
||||
|
||||
for node in stree.nodes():
|
||||
if node.get_name() == "layer1":
|
||||
if node.get_name() == "simple_net":
|
||||
subtree = TreeNodeHelper.get_sub_tree(node)
|
||||
orig_node_num = len(subtree.get_handler()._nodes)
|
||||
for n in subtree.nodes():
|
||||
|
@ -169,11 +194,11 @@ def test_erase_subtree_node_01():
|
|||
stree = SymbolTree.create(net)
|
||||
|
||||
for node in stree.nodes():
|
||||
if node.get_name() == "layer2":
|
||||
if node.get_name() == "simple_net":
|
||||
subtree = TreeNodeHelper.get_sub_tree(node)
|
||||
orig_node_num = len(subtree.get_handler()._nodes)
|
||||
for n in subtree.nodes():
|
||||
if n.get_name() == "cell_list_1":
|
||||
if n.get_name() == "block":
|
||||
input_node = n.get_inputs()[0]
|
||||
output_nodes = n.get_users()
|
||||
for _nn in output_nodes:
|
||||
|
@ -203,10 +228,10 @@ def test_erase_subtree_node_02():
|
|||
net = ForNetWithSubTree()
|
||||
stree = SymbolTree.create(net)
|
||||
for node in stree.nodes():
|
||||
if node.get_name() == "layer2":
|
||||
if node.get_name() == "simple_net":
|
||||
subtree = TreeNodeHelper.get_sub_tree(node)
|
||||
for n in subtree.nodes():
|
||||
if n.get_name() == "cell_list_1":
|
||||
if n.get_name() == "block":
|
||||
subtree1 = TreeNodeHelper.get_sub_tree(n)
|
||||
_remove_bn(subtree1)
|
||||
assert subtree1.get_node("bn1") is None
|
||||
|
@ -231,10 +256,10 @@ def test_insert_subtree_node():
|
|||
net = ForNetWithSubTree()
|
||||
stree = SymbolTree.create(net)
|
||||
for node in stree.nodes():
|
||||
if node.get_name() == "layer2" and node.get_node_type() == NodeType.Tree:
|
||||
if node.get_name() == "simple_net" and node.get_node_type() == NodeType.Tree:
|
||||
subtree = TreeNodeHelper.get_sub_tree(node)
|
||||
for n in subtree.nodes():
|
||||
if n.get_name() == "cell_list_1":
|
||||
if n.get_name() == "block":
|
||||
subtree1 = TreeNodeHelper.get_sub_tree(n)
|
||||
orig_node_num = len(subtree1.get_handler()._nodes)
|
||||
_insert_node(subtree1)
|
||||
|
@ -251,7 +276,7 @@ def test_resnet_replace_121():
|
|||
stree: SymbolTree = SymbolTree.create(net)
|
||||
original_nodes_size = len(stree.get_handler()._nodes)
|
||||
for node in stree.nodes():
|
||||
if node.get_name() == "layer1" and node.get_node_type() == NodeType.Tree:
|
||||
if node.get_name() == "simple_net" and node.get_node_type() == NodeType.Tree:
|
||||
subtree = TreeNodeHelper.get_sub_tree(node)
|
||||
for n in subtree.nodes():
|
||||
if n.get_instance_type() == nn.Conv2d:
|
||||
|
@ -274,7 +299,7 @@ def test_resnet_replace_12m():
|
|||
stree: SymbolTree = SymbolTree.create(net)
|
||||
|
||||
for node in stree.nodes():
|
||||
if node.get_name() == "layer1" and node.get_node_type() == NodeType.Tree:
|
||||
if node.get_name() == "simple_net" and node.get_node_type() == NodeType.Tree:
|
||||
subtree = TreeNodeHelper.get_sub_tree(node)
|
||||
original_nodes_size = len(subtree.get_handler()._nodes)
|
||||
for n in subtree.nodes():
|
||||
|
@ -301,7 +326,7 @@ def test_node_fusion_in_subtree():
|
|||
stree: SymbolTree = SymbolTree.create(net)
|
||||
original_nodes_size = len(stree.get_handler()._nodes)
|
||||
for node in stree.nodes():
|
||||
if node.get_name() == "layer1" and node.get_node_type() == NodeType.Tree:
|
||||
if node.get_name() == "simple_net" and node.get_node_type() == NodeType.Tree:
|
||||
subtree = TreeNodeHelper.get_sub_tree(node)
|
||||
original_nodes_size = len(subtree.get_handler()._nodes)
|
||||
for n in subtree.nodes():
|
||||
|
|
|
@ -133,7 +133,6 @@ def test_simple_net():
|
|||
net = SimpleNet(10)
|
||||
stree = SymbolTree.create(net)
|
||||
transform(stree)
|
||||
print("------------------------------------ keys of global_vars: ", stree.get_handler().get_global_vars().keys())
|
||||
net_opt = stree.get_network()
|
||||
data_in = Tensor(np.ones([1, 1, 32, 32]), mindspore.float32)
|
||||
_cell_graph_executor.compile(net_opt, data_in)
|
||||
|
|
Loading…
Reference in New Issue