!45091 fix ckpt param name bug

Merge pull request !45091 from 于振华/fix_rewrite_paramname_1101
This commit is contained in:
i-robot 2022-12-09 01:55:28 +00:00 committed by Gitee
commit 88c93e8e28
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
16 changed files with 694 additions and 75 deletions

View File

@ -172,6 +172,7 @@
"mindspore/tests/ut/python/mindir/test_mindir_export.py" "no-else-return"
"mindspore/tests/" "c-extension-no-member"
"mindspore/tests/st/parameter/test_parameter_celllist.py" "protected-access"
"mindspore/tests/ut/python/rewrite/test_cellcontainer.py" "protected-access"
#MindSpore Lite
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental/HPC-generator/generator.py" "redefined-builtin"

View File

@ -279,6 +279,32 @@ class SequentialCell(Cell):
input_data = cell(input_data)
return input_data
def _insert(self, index, cell):
"""
Inserts a given Cell before a given index in the list.
Args:
index(int): The Insert index in the CellList.
cell(Cell): The Cell to be inserted.
"""
cls_name = self.__class__.__name__
idx = _valid_index(len(self), index, cls_name)
_valid_cell(cell, cls_name)
length = len(self)
prefix, key_index = _get_prefix_and_index(self._cells)
while length > idx:
if self._auto_prefix:
tmp_cell = self._cells[str(length-1)]
for _, param in tmp_cell.parameters_and_names():
param.name = prefix + str(length) + "." + ".".join(param.name.split(".")[key_index+1:])
self._cells[str(length)] = self._cells[str(length - 1)]
length -= 1
self._cells[str(idx)] = cell
if self._auto_prefix:
cell.update_parameters_name(prefix + str(idx) + ".")
self.cell_list = list(self._cells.values())
self._is_dynamic_name.insert(index, True)
class CellList(_CellListBase, Cell):
"""

View File

@ -43,3 +43,4 @@ class NodeType(Enum):
Input = 7
Output = 8
Tree = 9
CellContainer = 10

View File

@ -308,6 +308,16 @@ class PatternEngine:
queue.extend(inputs_dict.get(cur_node.get_name()))
return new_root
@staticmethod
def _multi_replace_cellcontainer(stree, cellcontainer, node, matched_dict, new_nodes):
"""Replace node in CellContainer."""
to_erase_list = list(matched_dict.values())
stree.replace(Node(node), new_nodes)
for n in reversed(to_erase_list):
if n.get_handler() is node:
continue
stree.erase_node(n)
def apply(self, stree: SymbolTree) -> bool:
"""
Apply current pattern to a `SymbolTree`.
@ -359,6 +369,9 @@ class PatternEngine:
visited.append(cur_node)
queue.extend(cur_node.get_users())
continue
if cur_node.get_node_type() == NodeType.CellContainer:
self._process_cellcontainer(stree, cur_node.get_handler())
continue
visited.append(cur_node)
matched, matched_dict = self._match(self._pattern, cur_node)
# not matched
@ -460,3 +473,21 @@ class PatternEngine:
logger.debug("Check match failed, pattern leaked")
return False
return True
def _process_cellcontainer(self, stree, cellcontainer):
"""Process CellContainer node."""
for node in cellcontainer.nodes():
if node.get_node_type() == NodeType.Tree:
subtree = node.symbol_tree
self.apply(SymbolTree(subtree))
continue
else:
matched, matched_dict = self._match(self._pattern, Node(node))
if not matched:
continue
new_nodes = []
if self._replacement is not None:
new_nodes = self._replacement(self._pattern, self._is_chain, matched_dict)
if not new_nodes: # if replacement is empty, do nothing
continue
PatternEngine._multi_replace_cellcontainer(stree, cellcontainer, node, matched_dict, new_nodes)

View File

@ -241,8 +241,10 @@ class AstModifier(ast.NodeTransformer):
An instance of ast.Assign which has been appended to 'init_func'.
"""
return AstModifier.insert_assign_to_function(init_func, targets=targets,
args=[ScopedValue.create_variable_value(field)],
expr=ScopedValue(ValueType.NamingValue, "global_vars", "get"))
expr=ScopedValue(ValueType.NamingValue, "", "setattr"),
args=[ScopedValue(ValueType.NamingValue, "obj"),
ScopedValue.create_variable_value(field)])
@staticmethod
def create_call_assign(targets: [ScopedValue], expr: ScopedValue, args: [ScopedValue],

View File

@ -24,8 +24,6 @@ _ms_functional_ns = CellNamespace('mindspore.ops.functional')
def is_subtree(cls_name):
"""Determine whether 'cls_name' is a subtree."""
if cls_name == "SequentialCell":
return True
if cls_name == "QuantizeWrapperCell":
return False
if cls_name in _ms_common_ns or cls_name in _ms_nn_ns or cls_name in _ms_ops_ns:

View File

@ -624,7 +624,7 @@ class Node:
"""
self._targets = targets
if self._node_type in (NodeType.CallCell, NodeType.CallMethod, NodeType.CallPrimitive,
NodeType.Tree, NodeType.CallFunction):
NodeType.Tree, NodeType.CallFunction, NodeType.CellContainer):
self._sync_assign_targets_to_ast()
def get_func(self) -> ScopedValue:
@ -1135,7 +1135,7 @@ class Node:
def _sync_arg(self):
"""Sync _normalized_args to corresponding ast node when updated."""
if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree):
if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree, NodeType.CellContainer):
self._sync_call_cell_args_to_ast()
elif self._node_type == NodeType.Output:
self._sync_return_node_to_ast()
@ -1198,3 +1198,85 @@ class TreeNode(Node):
if ast_node is None:
ast_node = AstModifier.create_call_assign(new_targets, func, non_custom_args, non_custom_kwargs)
return cls(tree, ast_node, new_targets, func, args, kwargs, name, instance)
class CellContainer(Node):
""" Container for saving cell-objects node. """
class _Visitor():
""" A iterator of CellContainer nodes. """
def __init__(self, cellcontainer):
self._cellcontainer = cellcontainer
def __len__(self):
""" Get the number of nodes. """
return self._cellcontainer.node_count
def __iter__(self):
"""Create an iterator over the CellContainer."""
count = len(self._cellcontainer.node_list)
i = 0
while i < count:
curr = self._cellcontainer.node_list[i]
if curr.valid:
yield curr
i += 1
def __init__(self, ast_node: ast.AST, targets: [ScopedValue], func: ScopedValue,
args: [ScopedValue], kwargs: {str: ScopedValue}, name: str, instance):
"""Constructor of CellContainer.
Args:
ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast.
targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
func ([ScopedValue, optional]): An instance of ScopedValue. See detail in docstring of Node class.
args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class.
name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree.
Name of node also used as field name in network class.
instance: Object in network corresponding to this node.
"""
if isinstance(func, str):
func = ScopedValue.create_naming_value(func)
super().__init__(NodeType.CellContainer, ast_node, targets, func, args, kwargs, name, instance)
self._node_list = list()
self._node_count = 0
@property
def node_count(self):
"""Number of nodes."""
return self._node_count
@node_count.setter
def node_count(self, count):
"""Set number of nodes."""
self._node_count = count
@property
def node_list(self):
""" Get node list. """
return self._node_list
def append(self, node):
""" Append new node to node list. """
self._node_list.append(node)
self.get_instance().append(node.get_instance())
self.node_count += 1
def erase(self, node):
"""Erase node form container."""
index = self.node_list.index(node)
setattr(node, "valid", False)
self.node_count -= 1
index = self.get_instance().cell_list.index(node.get_instance())
del self.get_instance()[index]
def insert(self, index, node):
"""Insert node into container"""
self.node_list.insert(index, node)
setattr(node, "valid", True)
self.get_instance()._insert(index, node.get_instance())
self.node_count += 1
def nodes(self):
""" Return a iterator of node."""
return self._Visitor(self)

View File

@ -19,13 +19,13 @@ import astunparse
from mindspore import log as logger
from mindspore._extends.parse.namespace import CellNamespace
from mindspore.nn import Cell
from mindspore.nn import Cell, SequentialCell
from mindspore.ops import operations as P
from mindspore.ops import Primitive
from mindspore.rewrite.parser_register import ParserRegister
from mindspore.rewrite.namespace import is_subtree, is_functional, get_functional
from mindspore.rewrite.symbol_tree import SymbolTree
from mindspore.rewrite.node import Node, TreeNode
from mindspore.rewrite.node import Node, TreeNode, CellContainer
from mindspore.rewrite.parser import Parser
from mindspore.rewrite.parser_register import reg_parser
from mindspore.rewrite.api.scoped_value import ScopedValue, ValueType
@ -286,10 +286,10 @@ class AssignParser(Parser):
if target.attr != func_name:
continue
changed = True
global_vars_key = "_".join([func_name, "args"])
stree.add_global_vars(global_vars_key, sub_tree.get_global_vars())
args_call = AstModifier.create_call(ScopedValue.create_naming_value("get", "global_vars"),
[ScopedValue.create_variable_value(global_vars_key)])
setattr(stree.get_origin_network(), func_name, sub_tree.get_origin_network())
args_call = AstModifier.create_call(ScopedValue(ValueType.NamingValue, "", "getattr"),
[ScopedValue(ValueType.NamingValue, "", "obj"),
ScopedValue(ValueType.StringValue, "", func_name)])
body.value = ast.Call(func=ast.Name(class_name, ast.Store()), args=[args_call], keywords=[])
break
return changed
@ -308,6 +308,37 @@ class AssignParser(Parser):
call_args = [AssignParser._create_scopedvalue(arg) for arg in father_ast_node.value.args]
return Node.create_call_buildin_op(op, father_ast_node, targets, func, call_args, {})
def _cell_container_process(self, ast_node, stree, targets, func, call_args, call_kwargs, op_name, container_obj):
""" parse cell container object."""
cell_container = CellContainer(ast_node, targets, func, call_args, call_kwargs, op_name, container_obj)
for i, cell in enumerate(container_obj):
is_sub_tree = is_subtree(type(cell).__name__)
if is_sub_tree:
stb = SymbolTreeBuilder(cell)
new_stree = stb.build()
replacer = AstReplacer(new_stree.get_class_ast())
replacer.replace_all(new_stree.get_ori_cls_name(), new_stree.get_opt_cls_name())
tree = TreeNode.create_tree_node(new_stree, ast_node, targets, func, call_args, call_kwargs,
type(cell).__name__, cell)
setattr(tree, "container", cell_container)
setattr(tree, "valid", True)
tree.set_belong_symbol_tree(stree)
cell_container.node_list.append(tree)
cell_container.node_count += 1
if i > 0:
tree.set_inputs([cell_container.node_list[i-1]])
else:
node = Node.create_call_buildin_op(cell, ast_node, targets, func, call_args, call_kwargs,
type(cell).__name__)
setattr(node, "container", cell_container)
setattr(node, "valid", True)
node.set_belong_symbol_tree(stree)
cell_container.node_list.append(node)
cell_container.node_count += 1
if i > 0:
node.set_inputs([cell_container.node_list[i-1]])
return cell_container
def _convert_ast_call_to_node(self, ast_node: ast.Call, father_ast_node: ast.Assign, stree: SymbolTree) -> Node:
"""
Convert ast.Call to a symbol tree node.
@ -343,6 +374,10 @@ class AssignParser(Parser):
return node
raise RuntimeError(error_str(f"operator instance undefined.",
child_node=ast_node.func, father_node=ast_node))
if isinstance(op, SequentialCell):
node = self._cell_container_process(father_ast_node, stree, targets, func, call_args, call_kwargs,
func_name, op)
return node
if isinstance(op, Primitive):
return Node.create_call_buildin_op(op, father_ast_node, targets, func, call_args, call_kwargs, func_name)
if isinstance(op, Cell):

View File

@ -21,8 +21,7 @@ from mindspore._extends.parse.namespace import CellNamespace
from ..symbol_tree import SymbolTree
from ..parser import Parser
from ..parser_register import ParserRegister, reg_parser
from ..api.scoped_value import ScopedValue
from ..ast_helpers import AstReplacer, AstModifier
from ..ast_helpers import AstReplacer
from ..common import error_str
@ -124,9 +123,6 @@ class ClassDefParser(Parser):
super_index = ClassDefParser._find_super_expr_of_init_func(init_ast)
ClassDefParser._modify_arguments_of_init_func(init_ast)
self._replace_ori_field_of_init_func(stree, init_ast.body, super_index)
# re-find super_index for init_func changed in _replace_ori_field_of_init_func
super_index = ClassDefParser._find_super_expr_of_init_func(init_ast)
ClassDefParser._insert_handler_to_init_func(init_ast, super_index)
@staticmethod
def _find_super_expr_of_init_func(ast_init_fn: ast.FunctionDef) -> int:
@ -158,7 +154,7 @@ class ClassDefParser(Parser):
def _modify_arguments_of_init_func(ast_init_fn: ast.FunctionDef):
"""Replace init function input parameters to self and global_vars."""
arg_self = ast.arg(arg="self", annotation="")
arg_global_vars = ast.arg(arg="global_vars", annotation="")
arg_global_vars = ast.arg(arg="obj", annotation="")
ast_init_fn.args = ast.arguments(args=[arg_self, arg_global_vars], posonlyargs=[], kwonlyargs=[],
kw_defaults=[], defaults=[], vararg=None, kwarg=None)
ast.fix_missing_locations(ast_init_fn)
@ -235,22 +231,12 @@ class ClassDefParser(Parser):
continue
field_name = target.attr
body.value = ast.Call(ast.Name('getattr', ast.Load()),
[ast.Attribute(ast.Name('self', ast.Load()), '_handler', ast.Load()),
[ast.Name('obj', ast.Load()),
ast.Constant(value=field_name, kind=None)], [])
for counter, index in enumerate(body_index_to_be_deleted):
bodies.pop(index - counter)
ClassDefParser._remove_empty_ast_in_init_func(bodies)
@staticmethod
def _insert_handler_to_init_func(ast_init_fn: ast.FunctionDef, super_index):
"""Insert 'self._handler = global_vars.get('handler')' to init ast.FunctionDef.body"""
if super_index == -1:
super_index = 0
AstModifier.insert_assign_to_function(ast_init_fn, [ScopedValue.create_naming_value("_handler", "self")],
ScopedValue.create_naming_value("get", "global_vars"),
[ScopedValue.create_variable_value("handler")], None,
ast_init_fn.body[super_index], False)
def process(self, stree: SymbolTree, node: ast.ClassDef):
"""
Parse init and construct in ast.ClassDef.

View File

@ -34,12 +34,13 @@ class ForParser(Parser):
def modify_init_ast(stree, i, obj, iter_var_name):
"""Modify the ast node in init function."""
target = f"{iter_var_name.strip()}_{str(i)}"
stree.add_global_vars(target, obj)
setattr(stree.get_origin_network(), target, obj)
stree.get_origin_network().insert_child_to_cell(target, obj)
AstModifier.insert_assign_to_function(stree.get_init_func_ast(),
targets=[ScopedValue(ValueType.NamingValue, "self", target)],
expr=ScopedValue(ValueType.NamingValue, "global_vars", "get"),
args=[ScopedValue(ValueType.StringValue, "", target)])
expr=ScopedValue(ValueType.NamingValue, "", "getattr"),
args=[ScopedValue(ValueType.NamingValue, "", "obj"),
ScopedValue(ValueType.StringValue, "", target)])
@staticmethod
def modify_construct_ast(stree, ast_node, old_name, new_name):

View File

@ -44,7 +44,7 @@ class FunctionDefParser(Parser):
else:
parser.process(stree, body)
for body in node.body:
for body in node.body[::-1]:
if isinstance(body, (ast.For, ast.If)):
node.body.remove(body)
if hasattr(node, "decorator_list"):

View File

@ -160,7 +160,6 @@ class SymbolTree(Observer, Observable):
self._topo_mgr = TopoManager()
self._topo_mgr.reg_observer(self)
self._global_vars: {str, object} = {origin_network_key: origin_network}
self._nodes: {str, Node} = {}
# parameters of forward method
self._inputs: [Node] = []
@ -484,17 +483,6 @@ class SymbolTree(Observer, Observable):
"""
return self._origin_network
def get_global_vars(self):
"""Get global variables."""
return self._global_vars
def add_global_vars(self, key: str, value):
"""Add global variables."""
if self._global_vars.get(key) is not None:
logger.info(f"The key '{key}' is duplicated")
return
self._global_vars[key] = value
def get_nodes_dict(self):
"""Get dict of nodes"""
return self._nodes
@ -614,7 +602,6 @@ class SymbolTree(Observer, Observable):
RuntimeError: If 'node_or_name' is not belong to this SymbolTree or any sub-SymbolTree of current
SymbolTree.
"""
node = self._get_real_node(node_or_name)
if node is None:
raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
@ -653,7 +640,12 @@ class SymbolTree(Observer, Observable):
RuntimeError: If 'position' is not in current SymbolTree.
RuntimeError: If corresponding ast node is not an ast.Assign when 'insert_to_ast' is True.
"""
if position is not None and hasattr(position.node, "container"):
cellcontainer = getattr(position.node, "container")
index = cellcontainer.node_list.index(position.node)
index = index if position.before_node else index + 1
cellcontainer.insert(index, node)
return node
# if position in current SymbolTree
if position is not None and position.symbol_tree is not self:
raise RuntimeError("Position is not in current SymbolTree:", position)
@ -683,10 +675,10 @@ class SymbolTree(Observer, Observable):
if not isinstance(node_ast, ast.Assign):
raise RuntimeError("Only support insert cell op now")
if isinstance(node, TreeNode):
global_vars_key = node.get_name() + "_args"
self.add_global_vars(global_vars_key, node.symbol_tree.get_global_vars())
args_call = AstModifier.create_call(ScopedValue.create_naming_value("get", "global_vars"),
[ScopedValue.create_variable_value(global_vars_key)])
setattr(self._origin_network, node.get_name(), node.get_instance())
args_call = AstModifier.create_call(ScopedValue(ValueType.NamingValue, "", "getattr"),
[ScopedValue(ValueType.NamingValue, "", "obj"),
ScopedValue(ValueType.StringValue, "", node.get_name())])
value = ast.Call(func=ast.Name(node.symbol_tree.get_opt_cls_name(), ast.Store(), lineno=0,
col_offset=0), args=[args_call], keywords=[], lineno=0, col_offset=0)
@ -703,12 +695,13 @@ class SymbolTree(Observer, Observable):
else:
AstModifier.insert_assign_to_function(self._init_func_ast,
targets=[ScopedValue(ValueType.NamingValue, "self", node_name)],
expr=ScopedValue(ValueType.NamingValue, "global_vars", "get"),
args=[ScopedValue(ValueType.StringValue, "", node_name)])
expr=ScopedValue(ValueType.NamingValue, "", "getattr"),
args=[ScopedValue(ValueType.NamingValue, "", "obj"),
ScopedValue(ValueType.StringValue, "", node_name)])
AstModifier.insert_assign_ast_to_function(self._root_ast, node_ast,
None if position is None else position.node.get_ast(),
position.before_node)
self._global_vars[node_name] = node.get_instance()
setattr(self._origin_network, node_name, node.get_instance())
return node
def append_node(self, node: Node, append_to_ast: bool = True) -> Node:
@ -851,6 +844,10 @@ class SymbolTree(Observer, Observable):
node = self._get_real_node(node_or_name)
if node is None:
raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
if hasattr(node, "container"):
cellcontainer = getattr(node, "container")
cellcontainer.erase(node)
return node
ret = AstModifier.erase_ast_from_function(self._root_ast, node.get_ast())
if not ret:
raise RuntimeError("node not in function ast tree.")
@ -884,6 +881,9 @@ class SymbolTree(Observer, Observable):
RuntimeError: If 'old_node' is not belong to current SymbolTree.
"""
if hasattr(old_node, "container"):
self._replace_container_node(old_node, new_nodes)
return new_nodes[0]
real_old_node = self._get_real_node(old_node)
if real_old_node is None:
raise RuntimeError("Old node is not belong to current SymbolTree:", old_node)
@ -1026,7 +1026,7 @@ class SymbolTree(Observer, Observable):
A network object.
"""
cls = self._get_cls_through_file()
return cls(self._global_vars)
return cls(self._origin_network)
def set_saved_file_name(self, file_name: str):
"""Sets the filename used to save the network."""
@ -1070,6 +1070,14 @@ class SymbolTree(Observer, Observable):
else:
body.names.remove(alias)
def _replace_container_node(self, old_node, new_nodes):
cellcontainer = getattr(old_node, "container")
index = cellcontainer.node_list.index(old_node)
for n in reversed(new_nodes):
cellcontainer.insert(index, n)
index = cellcontainer.node_list.index(old_node)
cellcontainer.erase(old_node)
def _filter_out_to_delete_field(self, to_delete_field):
"""filter out used field from `to_delete_field`"""
# filter _handler field
@ -1077,7 +1085,8 @@ class SymbolTree(Observer, Observable):
to_delete_field.pop("_handler")
# filter field used in node of construct
for node in self._nodes.values():
if node.get_node_type() in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree):
if node.get_node_type() in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree,
NodeType.CellContainer):
func: ScopedValue = node.get_func()
if func.scope == "self" and to_delete_field.get(func.value):
to_delete_field.pop(func.value)
@ -1144,12 +1153,9 @@ class SymbolTree(Observer, Observable):
self._module_ast.body.remove(body)
def _get_real_node(self, node_or_name: Union[Node, str]) -> Optional[Node]:
if isinstance(node_or_name, Node):
result = self.get_node(node_or_name.get_name())
return result if result is node_or_name else None
if isinstance(node_or_name, str):
return self.get_node(node_or_name)
return None
return node_or_name
def _insert_tree(self, position: Position, root: Node, insert_to_ast: bool = True) -> Node:
"""
@ -1298,7 +1304,7 @@ class SymbolTree(Observer, Observable):
raise TypeError("value should be ScopedValue, got: ", type(value))
if value.type == ValueType.CustomObjValue:
field = self._node_name_namer.get_name(f"var_{type(value.value).__name__}")
self._global_vars[field] = value.value
setattr(self._origin_network, field, value.value)
init_targets = [ScopedValue.create_naming_value(field, "self")]
AstModifier.append_global_vars_expr_to_init(self._init_func_ast, init_targets, field)
result[arg] = init_targets[0]
@ -1316,7 +1322,8 @@ class SymbolTree(Observer, Observable):
Returns:
A class handle.
"""
file_name = "new_network_{0}.py".format(int(time.time() * 10000))
self._update_container()
file_name = "new_network_{0}.py".format(int(time.time() * 10000000))
with os.fdopen(os.open(file_name, os.O_WRONLY | os.O_CREAT, stat.S_IRWXU), 'wb') as f:
source = self.get_code()
f.write(source.encode('utf-8'))
@ -1333,3 +1340,18 @@ class SymbolTree(Observer, Observable):
def _on_change(self, event: Event):
self._modified = True
self.changed(event)
def _update_container(self):
"""Update instance of node in container."""
for node in self.nodes():
index = 0
if node.get_node_type() == NodeType.CellContainer:
for n in node.node_list:
if not n.valid:
continue
if n.get_node_type() == NodeType.Tree:
obj = n.symbol_tree.get_network()
node.get_instance()[index] = obj
else:
node.get_instance()[index] = n.get_instance()
index += 1

View File

@ -100,7 +100,11 @@ def _insert_cast_operator(stree):
if node.get_targets() is None:
continue
in_white_list = False
if node.get_node_type() != ms.rewrite.NodeType.Tree:
if node.get_node_type() == ms.rewrite.NodeType.CellContainer:
for n in node.get_handler().node_list:
if n.get_node_type() == ms.rewrite.NodeType.Tree:
_insert_cast_operator(ms.rewrite.TreeNodeHelper.get_sub_tree(ms.rewrite.Node(n)))
elif node.get_node_type() != ms.rewrite.NodeType.Tree:
# insert cast before the primitive operators in white_list
if node.get_instance_type() in AMP_WHITE_LIST_OPS:
in_white_list = True
@ -165,7 +169,11 @@ def _remove_duplicated_cast(stree):
for node in stree.nodes():
if node.get_targets() is None:
continue
if node.get_node_type() != ms.rewrite.NodeType.Tree:
if node.get_node_type() == ms.rewrite.NodeType.CellContainer:
for n in node.get_handler().node_list:
if n.get_node_type() == ms.rewrite.NodeType.Tree:
_remove_duplicated_cast(ms.rewrite.TreeNodeHelper.get_sub_tree(ms.rewrite.Node(n)))
elif node.get_node_type() != ms.rewrite.NodeType.Tree:
if node.get_instance_type() == P.Cast and _removed_cast_pair(node):
# remove the following cast node first
len_users = len(node.get_users())

View File

@ -0,0 +1,402 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cell container."""
from mindspore import nn
from mindspore.ops import operations as P
from mindspore.rewrite import SymbolTree, NodeType, TreeNodeHelper, Node, ScopedValue, PatternEngine, Replacement, \
PatternNode
def _conv3x3(in_channel, out_channel, stride=1):
return nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride,
padding=0, pad_mode='same', weight_init="ones")
def _conv1x1(in_channel, out_channel, stride=1):
return nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride,
padding=0, pad_mode='same', weight_init="ones")
def _bn(channel):
return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
class ResidualBlock(nn.Cell):
expansion = 4
def __init__(self,
in_channel,
out_channel,
stride=1):
super(ResidualBlock, self).__init__()
self.stride = stride
channel = out_channel // self.expansion
self.conv1 = _conv1x1(in_channel, channel, stride=1)
self.bn1 = _bn(channel)
self.conv2 = _conv3x3(channel, channel, stride=stride)
self.bn2 = _bn(channel)
self.conv3 = _conv1x1(channel, out_channel, stride=1)
self.bn3 = _bn(out_channel)
self.relu = nn.ReLU()
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride), _bn(out_channel)])
def construct(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
identity = self.down_sample_layer(identity)
out = out + identity
out = self.relu(out)
return out
class ResNetSimple(nn.Cell):
def __init__(self):
super(ResNetSimple, self).__init__(auto_prefix=True)
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad', weight_init="ones")
self.bn1 = _bn(16)
self.relu = P.ReLU()
self.layer1 = self._make_layer(ResidualBlock, 3, in_channel=63, out_channel=256, stride=1)
self.layer1.append(self.conv1)
self.layer1.append(self.bn1)
self.reshape = P.Reshape()
self.out_channels = 10
def construct(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.layer1(x)
return x
def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
layers = []
resnet_block = block(in_channel, out_channel, stride=stride)
layers.append(resnet_block)
for _ in range(1, layer_num):
resnet_block = ResidualBlock(out_channel, out_channel, stride=1)
layers.append(resnet_block)
return nn.SequentialCell(layers)
def test_cellcontainer_parse():
"""
Feature: parse CellContainer Node.
Description: parse a network with SequentialCell object.
Expectation: Rewrite can parse a network with SquentialCell object successfully.
"""
net = ResNetSimple()
stree = SymbolTree.create(net)
for node in stree.nodes():
if node.get_node_type() == NodeType.CellContainer:
assert len(node.get_handler().node_list) == 5
for i, n in enumerate(node.get_handler().node_list):
if i < 3:
assert n.get_instance_type() is ResidualBlock
if i == 3:
assert n.get_instance_type() is nn.Conv2d
if i == 4:
assert n.get_instance_type() is nn.BatchNorm2d
def test_cellcontainer_insert():
"""
Feature: modify CellContainer Node.
Description: using node in container to set insert location.
Expectation: raise ValueError.
"""
net = ResNetSimple()
stree = SymbolTree.create(net)
for node in stree.nodes():
if node.get_node_type() == NodeType.CellContainer:
assert len(node.get_handler().nodes()) == 5
for n in node.get_handler().nodes():
if n.get_instance_type() is nn.Conv2d:
position = stree.before(Node(n))
new_conv = nn.Conv2d(16, 16, 3)
new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv',
args=[ScopedValue.create_naming_value('self_max_po')])
stree.insert(position, new_conv_node)
break
assert len(node.get_handler().nodes()) == 6
assert node.get_handler().node_list[3].get_name() == "new_conv"
def test_cellcontainer_insert_ok():
"""
Feature: modify CellContainer Node.
Description: Inserts a node within a tree node in CellContainer Node.
Expectation: Insertion succeeded.
"""
def _insert_conv(stree: SymbolTree):
for node in stree.nodes():
if node.get_instance_type() == nn.BatchNorm2d:
position = stree.after(node)
new_conv = nn.Conv2d(16, 16, 3)
new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv',
args=[ScopedValue.create_naming_value('self_max_po')])
stree.insert(position, new_conv_node)
break
net = ResNetSimple()
stree = SymbolTree.create(net)
for node in stree.nodes():
if node.get_node_type() == NodeType.CellContainer:
for n in node.get_handler().node_list:
if n.get_node_type() == NodeType.Tree:
_insert_conv(TreeNodeHelper.get_sub_tree(Node(n)))
break
new_net = stree.get_network()
cell_container = getattr(new_net, "layer1")
assert hasattr(cell_container._cells["0"], "new_conv")
def test_cellcontainer_insert_to_subtree():
"""
Feature: modify CellContainer Node.
Description: Inserts a node within a tree node in CellContainer Node.
Expectation: Insertion succeeded.
"""
def _insert_conv(stree: SymbolTree):
for node in stree.nodes():
if node.get_instance_type() == nn.BatchNorm2d:
position = stree.after(node)
new_conv = nn.Conv2d(16, 16, 3)
new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv',
args=[ScopedValue.create_naming_value('self_max_po')])
stree.insert(position, new_conv_node)
break
net = ResNetSimple()
stree = SymbolTree.create(net)
for node in stree.nodes():
if node.get_node_type() == NodeType.CellContainer:
for n in node.get_handler().node_list:
if n.get_node_type() == NodeType.Tree:
_insert_conv(TreeNodeHelper.get_sub_tree(Node(n)))
break
new_net = stree.get_network()
cell_container = getattr(new_net, "layer1")
assert hasattr(cell_container._cells["0"], "new_conv")
def test_cellcontainer_del():
"""
Feature: modify CellContainer Node.
Description: delete the CellContainer Node.
Expectation: success.
"""
net = ResNetSimple()
stree = SymbolTree.create(net)
original_nodes_size = len(stree.get_handler()._nodes)
for node in stree.nodes():
if node.get_node_type() == NodeType.CellContainer and node.get_name() == "layer1":
users = node.get_users()
for user in users:
user.set_arg(0, "x")
stree.erase_node(node)
assert len(stree.get_handler()._nodes) == original_nodes_size - 1
def test_cellcontainer_del_node():
"""
Feature: modify CellContainer Node.
Description: delete the CellContainer Node.
Expectation: success.
"""
net = ResNetSimple()
stree = SymbolTree.create(net)
for node in stree.nodes():
if node.get_node_type() == NodeType.CellContainer and node.get_name() == "layer1":
assert len(node.get_handler().nodes()) == 5
for n in node.get_handler().nodes():
users = node.get_users()
inputs = node.get_inputs()
for user in users:
user.set_arg_by_node(0, inputs[0])
stree.erase_node(Node(n))
break
assert len(node.get_handler().nodes()) == 4
def test_cellcontainer_del_node_in_subtree():
"""
Feature: modify CellContainer Node.
Description: delete a node within a tree node in CellContainer Node.
Expectation: success.
"""
def _del_node(sub_tree):
for _node in sub_tree.nodes():
if _node.get_name() == "conv2":
users = Node(_node).get_users()
for user in users:
user.set_arg(0, "out")
sub_tree.erase_node(_node)
net = ResNetSimple()
stree = SymbolTree.create(net)
for node in stree.nodes():
if node.get_node_type() == NodeType.CellContainer:
for i, n in enumerate(node.get_handler().node_list):
if n.get_node_type() == NodeType.Tree and i == 1:
sub_tree = n.symbol_tree
original_nodes_size = len(sub_tree._nodes)
_del_node(sub_tree)
assert len(sub_tree._nodes) == original_nodes_size - 1
new_net = stree.get_network()
cell_container = getattr(new_net, "layer1")
assert not hasattr(cell_container._cells["1"], "conv2")
def test_cellcontainer_replace():
"""
Feature: modify CellContainer Node.
Description: replace CellContainer Node with another Node.
Expectation: success.
"""
def _replace_bn(stree: SymbolTree):
for node in stree.nodes():
if node.get_node_type() == NodeType.CellContainer:
new_conv = nn.Conv2d(16, 16, 3)
new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv',
args=[ScopedValue.create_naming_value('x')])
stree.replace(node, [new_conv_node])
break
net = ResNetSimple()
stree = SymbolTree.create(net)
_replace_bn(stree)
new_net = stree.get_network()
assert not hasattr(new_net, "layer1")
assert hasattr(new_net, "new_conv")
def test_cellcontainer_replace_node():
"""
Feature: modify CellContainer Node.
Description: replace the CellContainer Node.
Expectation: success.
"""
net = ResNetSimple()
stree = SymbolTree.create(net)
for node in stree.nodes():
if node.get_node_type() == NodeType.CellContainer and node.get_name() == "layer1":
for n in node.get_handler().nodes():
new_conv = nn.Conv2d(16, 16, 3)
new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv',
args=[ScopedValue.create_naming_value('x')])
stree.replace(Node(n), [new_conv_node])
break
assert node.get_handler().node_list[0].get_name() == "new_conv"
assert isinstance(node.get_handler().get_instance()._cells["0"], nn.Conv2d)
break
def test_cellcontainer_replace_in_subtree():
"""
Feature: modify CellContainer Node.
Description: replace a node within a tree node in CellContainer Node.
Expectation: success.
"""
def _replace_bn(stree: SymbolTree):
for node in stree.nodes():
if node.get_name() == "bn1":
new_conv = nn.Conv2d(16, 16, 3)
new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv',
args=[ScopedValue.create_naming_value('self_max_po')])
stree.replace(node, [new_conv_node])
break
net = ResNetSimple()
stree = SymbolTree.create(net)
for node in stree.nodes():
if node.get_node_type() == NodeType.CellContainer:
for n in node.get_handler().node_list:
if n.get_node_type() == NodeType.Tree:
_replace_bn(TreeNodeHelper.get_sub_tree(Node(n)))
break
new_net = stree.get_network()
cell_container = getattr(new_net, "layer1")
assert not hasattr(cell_container._cells["0"], "bn1")
assert hasattr(cell_container._cells["0"], "new_conv")
def test_cellcontainer_pattern():
"""
Feature: modify CellContainer Node.
Description: apply pattern matching and replacement on the network containing SequentialCell object.
Expectation: success.
"""
class ConvBnReplacement(Replacement):
def build(self, pattern: PatternNode, is_chain_pattern: bool, matched):
assert is_chain_pattern
assert pattern.type() == nn.BatchNorm2d
bn_node: Node = matched.get(pattern.name())
assert bn_node is not None
assert len(pattern.get_inputs()) == 1
add_pattern = pattern.get_inputs()[0]
assert add_pattern.type() == nn.Conv2d
add_node: Node = matched.get(add_pattern.name())
assert add_node is not None
assert not add_pattern.get_inputs()
new_maxpool1 = nn.MaxPool2d()
new_maxpool1_node = Node.create_call_cell(new_maxpool1, ['new_maxpool1'], add_node.get_args())
new_relu1 = nn.ReLU()
new_relu1_node = Node.create_call_cell(new_relu1, ['new_relu_1'],
[ScopedValue.create_naming_value('new_maxpool1')])
new_relu2 = nn.ReLU()
new_relu2_node = Node.create_call_cell(new_relu2, ['new_relu_2'],
[ScopedValue.create_naming_value('new_maxpool1')])
new_maxpool2 = nn.BiDense(1, 1, 2)
new_maxpool2_node = Node.create_call_cell(new_maxpool2, ['new_maxpool2'],
[ScopedValue.create_naming_value('new_relu_1'),
ScopedValue.create_naming_value('new_relu_2')])
return [new_maxpool1_node, new_relu1_node, new_relu2_node, new_maxpool2_node]
class ConvReluPattern(PatternEngine):
def __init__(self):
super().__init__([nn.Conv2d, nn.BatchNorm2d], ConvBnReplacement())
net = ResNetSimple()
stree = SymbolTree.create(net)
_pattern = ConvReluPattern()
_pattern.apply(stree)
new_net = stree.get_network()
cell_container = getattr(new_net, "layer1")
assert not hasattr(cell_container, "conv1")
assert not hasattr(cell_container, "bn1")
assert not hasattr(cell_container._cells["0"], "conv1")
assert not hasattr(cell_container._cells["1"], "conv1")
assert not hasattr(cell_container._cells["2"], "conv1")
assert hasattr(cell_container._cells["0"], "new_relu")
assert hasattr(cell_container._cells["0"], "new_maxpool1")
assert isinstance(getattr(getattr(cell_container._cells["0"], "down_sample_layer"), "0"), nn.MaxPool2d)
assert hasattr(cell_container._cells["1"], "new_relu")
assert hasattr(cell_container._cells["1"], "new_maxpool1")
assert isinstance(getattr(getattr(cell_container._cells["1"], "down_sample_layer"), "0"), nn.MaxPool2d)
assert hasattr(cell_container._cells["2"], "new_relu")
assert hasattr(cell_container._cells["2"], "new_maxpool1")
assert isinstance(getattr(getattr(cell_container._cells["2"], "down_sample_layer"), "0"), nn.MaxPool2d)
assert isinstance(getattr(cell_container, "3"), nn.MaxPool2d)
assert isinstance(getattr(cell_container, "4"), nn.ReLU)
assert isinstance(getattr(cell_container, "6"), nn.BiDense)

View File

@ -2,6 +2,7 @@
from collections import OrderedDict
from mindspore import nn
from mindspore.ops import operations as P
from mindspore.rewrite import SymbolTree, PatternEngine, Replacement, PatternNode, Node, ScopedValue
from mindspore.rewrite.api.tree_node_helper import TreeNodeHelper
from mindspore.rewrite.api.node_type import NodeType
@ -108,6 +109,28 @@ class CellBlock(nn.Cell):
return out
class SimpleNet(nn.Cell):
def __init__(self):
super().__init__()
self.mul = P.Mul()
self.dense = nn.Dense(in_channels=32, out_channels=32, weight_init="ones")
self.mean = P.ReduceMean(keep_dims=False)
self.split = P.Split(axis=1, output_num=3)
self.conv1 = nn.Conv2d(3, 64, 3, stride=2)
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.block = CellBlock(3, 6)
def construct(self, x):
y, _, _ = self.split(x)
y = self.mean(y, (2, 3))
x = self.mul(x, 1)
x = self.block(x)
x = self.conv1(x)
x = self.max_pool2d(x)
x = self.dense(x)
return x, y
class ForNetWithSubTree(nn.Cell):
def __init__(self):
super(ForNetWithSubTree, self).__init__()
@ -125,12 +148,14 @@ class ForNetWithSubTree(nn.Cell):
resnet_block3 = CellBlock(16, 32)
layers = [resnet_block1, resnet_block2, resnet_block3]
self.layer2 = nn.SequentialCell(layers)
self.simple_net = SimpleNet()
def construct(self, x):
x = self.conv1(x)
x = self.layer1(x)
x = self.relu(x)
x = self.layer2(x)
x = self.simple_net(x)
return x
@ -144,7 +169,7 @@ def test_erase_subtree_node():
stree = SymbolTree.create(net)
for node in stree.nodes():
if node.get_name() == "layer1":
if node.get_name() == "simple_net":
subtree = TreeNodeHelper.get_sub_tree(node)
orig_node_num = len(subtree.get_handler()._nodes)
for n in subtree.nodes():
@ -169,11 +194,11 @@ def test_erase_subtree_node_01():
stree = SymbolTree.create(net)
for node in stree.nodes():
if node.get_name() == "layer2":
if node.get_name() == "simple_net":
subtree = TreeNodeHelper.get_sub_tree(node)
orig_node_num = len(subtree.get_handler()._nodes)
for n in subtree.nodes():
if n.get_name() == "cell_list_1":
if n.get_name() == "block":
input_node = n.get_inputs()[0]
output_nodes = n.get_users()
for _nn in output_nodes:
@ -203,10 +228,10 @@ def test_erase_subtree_node_02():
net = ForNetWithSubTree()
stree = SymbolTree.create(net)
for node in stree.nodes():
if node.get_name() == "layer2":
if node.get_name() == "simple_net":
subtree = TreeNodeHelper.get_sub_tree(node)
for n in subtree.nodes():
if n.get_name() == "cell_list_1":
if n.get_name() == "block":
subtree1 = TreeNodeHelper.get_sub_tree(n)
_remove_bn(subtree1)
assert subtree1.get_node("bn1") is None
@ -231,10 +256,10 @@ def test_insert_subtree_node():
net = ForNetWithSubTree()
stree = SymbolTree.create(net)
for node in stree.nodes():
if node.get_name() == "layer2" and node.get_node_type() == NodeType.Tree:
if node.get_name() == "simple_net" and node.get_node_type() == NodeType.Tree:
subtree = TreeNodeHelper.get_sub_tree(node)
for n in subtree.nodes():
if n.get_name() == "cell_list_1":
if n.get_name() == "block":
subtree1 = TreeNodeHelper.get_sub_tree(n)
orig_node_num = len(subtree1.get_handler()._nodes)
_insert_node(subtree1)
@ -251,7 +276,7 @@ def test_resnet_replace_121():
stree: SymbolTree = SymbolTree.create(net)
original_nodes_size = len(stree.get_handler()._nodes)
for node in stree.nodes():
if node.get_name() == "layer1" and node.get_node_type() == NodeType.Tree:
if node.get_name() == "simple_net" and node.get_node_type() == NodeType.Tree:
subtree = TreeNodeHelper.get_sub_tree(node)
for n in subtree.nodes():
if n.get_instance_type() == nn.Conv2d:
@ -274,7 +299,7 @@ def test_resnet_replace_12m():
stree: SymbolTree = SymbolTree.create(net)
for node in stree.nodes():
if node.get_name() == "layer1" and node.get_node_type() == NodeType.Tree:
if node.get_name() == "simple_net" and node.get_node_type() == NodeType.Tree:
subtree = TreeNodeHelper.get_sub_tree(node)
original_nodes_size = len(subtree.get_handler()._nodes)
for n in subtree.nodes():
@ -301,7 +326,7 @@ def test_node_fusion_in_subtree():
stree: SymbolTree = SymbolTree.create(net)
original_nodes_size = len(stree.get_handler()._nodes)
for node in stree.nodes():
if node.get_name() == "layer1" and node.get_node_type() == NodeType.Tree:
if node.get_name() == "simple_net" and node.get_node_type() == NodeType.Tree:
subtree = TreeNodeHelper.get_sub_tree(node)
original_nodes_size = len(subtree.get_handler()._nodes)
for n in subtree.nodes():

View File

@ -133,7 +133,6 @@ def test_simple_net():
net = SimpleNet(10)
stree = SymbolTree.create(net)
transform(stree)
print("------------------------------------ keys of global_vars: ", stree.get_handler().get_global_vars().keys())
net_opt = stree.get_network()
data_in = Tensor(np.ones([1, 1, 32, 32]), mindspore.float32)
_cell_graph_executor.compile(net_opt, data_in)