!33747 rewrite subgraph

Merge pull request !33747 from 于振华/rewrite_subgraph01
This commit is contained in:
i-robot 2022-05-05 07:07:47 +00:00 committed by Gitee
commit ad304d342e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
16 changed files with 323 additions and 78 deletions

View File

@ -88,7 +88,7 @@ class Node:
RuntimeError: If value of kwarg in `kwargs` is not a `NamingValue`-`ScopedValue` or a
`CustomObjValue`-`ScopedValue`.
"""
return Node(NodeImpl.create_call_buildin_op(cell, None, targets, ScopedValue.create_naming_value(name, "self"),
return Node(NodeImpl.create_call_op(cell, None, targets, ScopedValue.create_naming_value(name, "self"),
args, kwargs, name))
def get_prev(self) -> 'Node':

View File

@ -51,14 +51,15 @@ class SymbolTree:
"""
return self._symbol_tree
def nodes(self) -> {}:
def nodes(self):
"""
Get all nodes of corresponding network.
Returns:
A dict mapping from name of node to node.
"""
return [Node(node_impl) for node_impl in self._symbol_tree.nodes(unfold_subtree=False)]
for node in self._symbol_tree.nodes():
yield Node(node)
def get_node(self, node_name: str) -> Optional[Node]:
"""

View File

@ -19,7 +19,6 @@ from .symbol_tree import SymbolTree
from .node import Node
from .node_type import NodeType
from ..symbol_tree import SymbolTree as SymbolTreeImpl
from ..node import TreeNode
class TreeNodeHelper:
@ -47,7 +46,6 @@ class TreeNodeHelper:
if node.get_node_type() == NodeType.Tree:
node_impl = node.get_handler()
assert isinstance(node_impl, TreeNode)
subtree: SymbolTreeImpl = node_impl.symbol_tree
if subtree is None:
return None

View File

@ -0,0 +1,31 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define the namespace of MindSpore op definition."""
from .._extends.parse.namespace import CellNamespace
_ms_common_ns = CellNamespace('mindspore.common')
_ms_nn_ns = CellNamespace('mindspore.nn')
_ms_ops_ns = CellNamespace('mindspore.ops')
def is_subtree(cls_name):
"""Determine whether 'cls_name' is a subtree."""
if cls_name == "SequentialCell":
return True
if cls_name in _ms_common_ns or cls_name in _ms_nn_ns or cls_name in _ms_ops_ns:
return False
return True

View File

@ -23,6 +23,8 @@ from mindspore import log as logger
from .ast_helpers import AstModifier
from .api.scoped_value import ScopedValue, ValueType
from .api.node_type import NodeType
from .namespace import is_subtree
from .ast_helpers.ast_replacer import AstReplacer
PASS_THROUGH_METHOD = ScopedValue.create_naming_value("PassThrough")
@ -223,6 +225,41 @@ class Node:
return cls(NodeType.Output, ast_node, None, ScopedValue.create_naming_value("return"), real_return_values, {},
name, None)
@staticmethod
def create_call_op(op: Union[Cell, Primitive], ast_node: Optional[ast.AST],
targets: [Union[ScopedValue, str]], func: Union[ScopedValue, str],
args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None, name: str = ""):
"""
Static method of Node. Instantiate an instance of node whose type is `CallCell` or `CallPrimitive`.
If op is custom defined, it is treated by TreeNode.
A `CallCell` node represents an invoking to cell-op.
A `CallPrimitive` node represents an invoking to primitive-op.
Args:
op (Union[Cell, Primitive]): An instance of `Cell` or `Primitive` corresponding to this node.
ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast.
targets (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
func ([ScopedValue, optional]): An instance of `ScopedValue`. See detail in docstring of Node class.
args (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
kwargs (dict{str: ScopedValue}): A list of instance of `ScopedValue`. See detail in docstring of `Node`
class.
name (str): A string represents name of node. Name of node will be unique when inserted into `SymbolTree`.
Name of node also used as field name in network class.
"""
cls_name = type(op).__name__
if is_subtree(cls_name):
from .symbol_tree_builder import SymbolTreeBuilder
stb = SymbolTreeBuilder(op)
stree = stb.build()
replacer = AstReplacer(stree.get_class_ast())
replacer.replace_all(stree.get_ori_cls_name(), stree.get_opt_cls_name())
return TreeNode.create_tree_node(stree, None, targets, ScopedValue.create_naming_value(name, "self"),
args, kwargs, name, op)
return Node.create_call_buildin_op(op, None, targets, ScopedValue.create_naming_value(name, "self"),
args, kwargs, name)
@staticmethod
def _get_construct_arg_names(parameters):
"""
@ -440,6 +477,10 @@ class Node:
"""
return self._prev
def set_prev(self, prev):
"""Set previous node of current node in source code order. """
self._prev = prev
def get_next(self) -> 'Node':
"""
Get next node of current node in source code order.
@ -449,6 +490,10 @@ class Node:
"""
return self._next
def set_next(self, _next):
"""Set next node of current node in source code order."""
self._next = _next
def has_same_ast(self, node: Union['Node', ast.AST]) -> bool:
"""
Check if other node holds same ast node with self.
@ -460,7 +505,7 @@ class Node:
A bool.
"""
if isinstance(node, Node):
return self.has_same_ast(node._ast_node)
return self.has_same_ast(node.get_ast())
if isinstance(node, ast.AST):
return id(self._ast_node) == id(node)
return False
@ -570,7 +615,8 @@ class Node:
keyword_map_index[keyword_ast.arg] = index
for keyword_index in range(self._kwargs_num):
key = self._normalized_args_keys[keyword_index + self._args_num]
AstModifier.update_arg_value(self._normalized_args.get(key), keywords_ast[keyword_map_index[key]].value)
AstModifier.update_arg_value(self._normalized_args.get(key),
keywords_ast[keyword_map_index.get(key)].value)
def _sync_call_method_args_to_ast(self):
"""Sync args of ast.Cell of ast.Assign from self._normalized_args when NodeType is CallMethod."""
@ -646,9 +692,9 @@ class Node:
origin_prev: Optional[Node] = self._prev
origin_next: Optional[Node] = self._next
if origin_prev is not None:
origin_prev._next = origin_next
origin_prev.set_next(origin_next)
if origin_next is not None:
origin_next._prev = origin_prev
origin_next.set_prev(origin_prev)
self._prev = None
self._next = None
@ -662,9 +708,9 @@ class Node:
node.isolate()
origin_prev: Optional[Node] = self._prev
if origin_prev is not None:
origin_prev._next = node
node._prev = origin_prev
node._next = self
origin_prev.set_next(node)
node.set_prev(origin_prev)
node.set_next(self)
self._prev = node
def insert_after(self, node: 'Node'):
@ -677,10 +723,10 @@ class Node:
node.isolate()
origin_next: Optional[Node] = self._next
self._next = node
node._prev = self
node._next = origin_next
node.set_prev(self)
node.set_next(origin_next)
if origin_next is not None:
origin_next._prev = node
origin_next.set_prev(node)
def get_inputs(self) -> ['Node']:
"""
@ -850,12 +896,12 @@ class Node:
if arg_idx >= self._args_num or arg_idx < 0:
raise RuntimeError("arg_idx out of range: ", arg_idx)
if out_idx is None:
if len(node._targets) != 1:
if len(node.get_targets()) != 1:
raise RuntimeError("node should has one output when out_idx is not provided")
out_idx = 0
if out_idx >= len(node._targets):
if out_idx >= len(node.get_targets()):
raise RuntimeError("out_idx out of range: ", out_idx)
new_arg = node._targets[out_idx]
new_arg = node.get_targets()[out_idx]
self._normalized_args[self._normalized_args_keys[arg_idx]] = new_arg
self._sync_arg()

View File

@ -0,0 +1,44 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Visit nods of SymbolTree."""
class NodeVisitor:
"""Iterator class to access SymbolTree nodes"""
def __init__(self, stree):
self._stree = stree
self._nodes = []
self._index = 0
def __iter__(self):
self._nodes = list(self._stree.get_nodes_dict().values())
self._index = 0
return self
def __next__(self):
if self._index < len(self._nodes):
node = self._nodes[self._index]
self._index += 1
return node
raise StopIteration
def append_node(self, node):
"""append new node to iterator"""
self._nodes.append(node)
def remove_node(self, node):
"""remove node of iterator"""
self._nodes.remove(node)

View File

@ -186,7 +186,6 @@ class AssignParser(Parser):
return results
def _is_subtree_cell(self, cell: Cell) -> bool:
assert isinstance(cell, Cell)
return not type(cell).__name__ in self._cell_namespce
@staticmethod
@ -353,7 +352,6 @@ class AssignParser(Parser):
# self._subnet = SubNet1(global_vars.get("subnet_args"))
# so a change in sub-network should also be identified as a change in main-network.
# so main-network should observe sub-network
new_stree.reg_observer(stree)
replacer = AstReplacer(new_stree.get_class_ast())
replacer.replace_all(new_stree.get_ori_cls_name(), new_stree.get_opt_cls_name())
return TreeNode(new_stree, father_ast_node, targets, func, call_args, call_kwargs, func_name,

View File

@ -114,7 +114,6 @@ class ClassDefParser(Parser):
def _is_subtree_field(self, ori_net, field) -> bool:
op = getattr(ori_net, field)
assert op is not None
return not type(op).__name__ in self._cell_namespace
def _process_init_func_ast(self, stree: SymbolTree, init_ast: ast.FunctionDef):

View File

@ -32,6 +32,7 @@ from .namer import TargetNamer, NodeNamer, ClassNamer
from .common.observer import Observer
from .common.observable import Observable
from .common.event import Event
from .node_visitor import NodeVisitor
class Position:
@ -113,6 +114,7 @@ class SymbolTree(Observer, Observable):
self._return: Optional[Node] = None
self._modified = False
self._node_visitor = None
def finish_build(self):
self.add_event(Event.TopologicalChangeEvent)
@ -258,25 +260,27 @@ class SymbolTree(Observer, Observable):
raise RuntimeError("Key of global_vars duplicated:", key)
self._global_vars[key] = value
def nodes(self, unfold_subtree=False):
def get_nodes_dict(self):
"""Get dict of nodes"""
return self._nodes
def nodes(self):
"""
Getter of nodes if current SymbolTree.
Args:
unfold_subtree (bool): Need to iterate into sub-symbol-tree recursively.
Returns:
A list of instance of Nodes.
"""
if unfold_subtree:
nodes = []
for _, v in self._nodes.items():
if isinstance(v, TreeNode):
nodes.extend(v.symbol_tree.nodes(unfold_subtree))
else:
nodes.append(v)
return nodes
return self._nodes.values()
if self._node_visitor is None:
self._node_visitor = NodeVisitor(self)
it = iter(self._node_visitor)
while True:
try:
n = next(it)
yield n
except StopIteration:
return None
def get_node(self, node_name: str) -> Optional[Node]:
"""
@ -431,9 +435,34 @@ class SymbolTree(Observer, Observable):
# _unique_targets must called after _update_args_for_unique and _update_kwargs_for_unique
self._unique_targets(node)
self._insert_node(position, node)
if isinstance(node, TreeNode):
node.symbol_tree.reg_observer(self)
if self._node_visitor:
self._node_visitor.append_node(node)
# update init-function-ast and construct-function-ast
if insert_to_ast:
node.set_func(ScopedValue.create_naming_value(node_name, "self"))
if isinstance(node, TreeNode):
global_vars_key = node.get_name() + "_args"
self.add_global_vars(global_vars_key, node.symbol_tree.get_global_vars())
args_call = AstModifier.create_call(ScopedValue.create_naming_value("get", "global_vars"),
[ScopedValue.create_variable_value(global_vars_key)])
value = ast.Call(func=ast.Name(node.symbol_tree.get_opt_cls_name(), ast.Store(), lineno=0,
col_offset=0), args=[args_call], keywords=[], lineno=0, col_offset=0)
ast_target = ast.Name("self." + node.get_name(), ast.Store(), lineno=0, col_offset=0)
assign = ast.Assign(targets=[ast_target], value=value, lineno=0, col_offset=0)
AstModifier.insert_assign_ast_to_function(self._init_func_ast, assign)
assign_construct = AstModifier.create_call_assign(node.get_targets(), ScopedValue.create_naming_value
(node.get_name(), "self"), node.get_args(), {})
AstModifier.insert_assign_ast_to_function(self._root_ast, assign_construct,
None if position is None else position.node.get_ast(),
position.before_node)
sub_stree: SymbolTree = node.symbol_tree
from .symbol_tree_builder import SymbolTreeBuilder
SymbolTreeBuilder.merge_module_of_subtree(self, sub_stree)
else:
node_ast = node.get_ast()
if not isinstance(node_ast, ast.Assign):
raise RuntimeError("Only support insert cell op now")
@ -595,6 +624,8 @@ class SymbolTree(Observer, Observable):
value.isolate()
break
self._topo_mgr.on_erase_node(node)
if self._node_visitor:
self._node_visitor.remove_node(node)
return node
def _insert_tree(self, position: Position, root: Node, insert_to_ast: bool = True) -> Node:

View File

@ -44,6 +44,22 @@ class SymbolTreeBuilder:
self._ast_root: ast.Module = ast.parse(network_str)
self._root_tree: Optional[SymbolTree] = None
@staticmethod
def merge_module_of_subtree(main_tree: SymbolTree, sub_stree: SymbolTree):
"""
Merge ast.Module of sub-network into ast.Module of main-network.
1. Merge imports of ast.Module.
2. Merge classes of ast.Module.
3. Use merged ast.Module as module of main-network and sub-network.
"""
father_mod = main_tree.get_module_ast()
sub_mod = sub_stree.get_module_ast()
SymbolTreeBuilder._merge_import_of_module(father_mod, sub_mod)
SymbolTreeBuilder._merge_class_of_module(father_mod, sub_mod)
sub_stree.set_module_ast(father_mod)
@staticmethod
def _ast_transform(ast_root: ast.AST) -> ast.AST:
"""
@ -80,7 +96,6 @@ class SymbolTreeBuilder:
main_mod_finder = AstFinder(main_mod)
imports_in_sub = copy(sub_mod_finder.find_all((ast.Import, ast.ImportFrom)))
imports_in_main = copy(main_mod_finder.find_all((ast.Import, ast.ImportFrom)))
assert imports_in_main
first_import = imports_in_main[0]
for clazz in imports_in_sub:
AstModifier.insert_sub_ast(main_mod, clazz, first_import, True)
@ -103,12 +118,11 @@ class SymbolTreeBuilder:
main_mod_finder = AstFinder(main_mod)
classes_in_sub = copy(sub_mod_finder.find_all(ast.ClassDef))
classes_in_main = copy(main_mod_finder.find_all(ast.ClassDef))
assert classes_in_main
first_class = classes_in_main[0]
for clazz in classes_in_sub:
AstModifier.insert_class_into_module(main_mod, clazz, first_class, True)
def _merge_module_of_subtree(self):
def _merge_module_of_subtrees(self):
"""
Merge ast.Module of all sub-networks into ast.Module of main-network.
@ -117,13 +131,9 @@ class SymbolTreeBuilder:
3. Use merged ast.Module as module of main-network and sub-network.
"""
father_mod = self._root_tree.get_module_ast()
for node in self._root_tree.nodes():
if isinstance(node, TreeNode):
sub_stree: SymbolTree = node.symbol_tree
SymbolTreeBuilder._merge_import_of_module(father_mod, sub_stree.get_module_ast())
SymbolTreeBuilder._merge_class_of_module(father_mod, sub_stree.get_module_ast())
sub_stree.set_module_ast(father_mod)
SymbolTreeBuilder.merge_module_of_subtree(self._root_tree, node.symbol_tree)
def _reduce_redundant_import(self):
"""
@ -140,7 +150,6 @@ class SymbolTreeBuilder:
if isinstance(body, ast.Import):
names = body.names
for name in names:
assert isinstance(name, ast.alias)
import_hash = hash((name.name, name.asname))
if import_hash in exist_import:
continue
@ -150,7 +159,6 @@ class SymbolTreeBuilder:
import_module = body.module
names = body.names
for name in names:
assert isinstance(name, ast.alias)
import_hash = hash((import_module, name.name, name.asname))
if import_hash in exist_import_from:
continue
@ -182,7 +190,7 @@ class SymbolTreeBuilder:
self._root_tree: SymbolTree = SymbolTree(self._origin_net, self._ast_root)
parser: Parser = ParserRegister.instance().get_parser(ast.Module)
parser.process(self._root_tree, self._ast_root)
self._merge_module_of_subtree()
self._merge_module_of_subtrees()
self._reduce_redundant_import()
ast.fix_missing_locations(self._root_tree.get_module_ast())
self._root_tree.finish_build()

View File

@ -15,6 +15,7 @@
from mindspore.nn import Cell, Conv2d
from mindspore.rewrite import SymbolTree
from mindspore.ops import operations as P
from .utils import get_node_by_index
class SubNet(Cell):
@ -56,7 +57,8 @@ def test_multi_targets():
"""
test_cls = NetMultiTargets()
stree = SymbolTree.create(test_cls)
node = stree.nodes()[2]
node = get_node_by_index(stree, 2)
assert node is not None
targets = node.get_targets()
assert targets[0].value == 'c1'
assert targets[1].value == 'c2'

View File

@ -121,7 +121,6 @@ def erase_node_x_11(stree: SymbolTree):
def transform(stree: SymbolTree):
add_conv_before_flatten(stree)
add_my_cell_after_x_12(stree)
erase_node_x_11(stree)

View File

@ -19,6 +19,7 @@ from mindspore.nn import Cell, Conv2d, BatchNorm2d, ReLU
from mindspore.ops import Add, AddN
from mindspore.rewrite import ScopedValue, Node, SymbolTree
from mindspore.rewrite import PatternEngine, PatternNode, Replacement, VarNode
from .utils import get_symbol_tree_nodes_count
def test_tree_pattern_match():
@ -92,13 +93,13 @@ def test_one_to_one_pattern():
assert bn is not None
assert relu1 is not None
assert len(construct_ast.body) == 6
assert len(stree.nodes()) == 7
assert get_symbol_tree_nodes_count(stree) == 7
bn_replace = BnReplace()
bn_replace.apply(stree)
assert len(construct_ast.body) == 6
assert len(stree.nodes()) == 7
assert get_symbol_tree_nodes_count(stree) == 7
conv = stree.get_node("conv")
bn = stree.get_node("bn")
relu1 = stree.get_node("relu1")
@ -167,13 +168,13 @@ def test_one_to_multi_chain_pattern():
assert bn is not None
assert relu1 is not None
assert len(construct_ast.body) == 6
assert len(stree.nodes()) == 7
assert get_symbol_tree_nodes_count(stree) == 7
bn_replace = BnReplace()
bn_replace.apply(stree)
assert len(construct_ast.body) == 7
assert len(stree.nodes()) == 8
assert get_symbol_tree_nodes_count(stree) == 8
conv = stree.get_node("conv")
bn = stree.get_node("bn")
relu1 = stree.get_node("relu1")
@ -296,13 +297,13 @@ def test_tree_pattern():
assert relu2 is not None
construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast")
assert len(construct_ast.body) == 8
assert len(stree.nodes()) == 9
assert get_symbol_tree_nodes_count(stree) == 9
add_relu_pattern = AddReluPattern()
add_relu_pattern.apply(stree)
assert len(construct_ast.body) == 10
assert len(stree.nodes()) == 11
assert get_symbol_tree_nodes_count(stree) == 11
conv1 = stree.get_node("conv1")
conv2 = stree.get_node("conv2")
add = stree.get_node("add")
@ -481,13 +482,13 @@ def test_multi_input_to_multi_pattern_tree_pattern():
assert relu is not None
construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast")
assert len(construct_ast.body) == 6
assert len(stree.nodes()) == 9
assert get_symbol_tree_nodes_count(stree) == 9
multi_input_pattern = MultiInputPattern()
multi_input_pattern.apply(stree)
assert len(construct_ast.body) == 4
assert len(stree.nodes()) == 7
assert get_symbol_tree_nodes_count(stree) == 7
conv1 = stree.get_node("conv1")
conv2 = stree.get_node("conv2")
add1 = stree.get_node("add1")
@ -598,13 +599,13 @@ def test_one_input_to_multi_pattern_tree_pattern():
assert relu is not None
construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast")
assert len(construct_ast.body) == 6
assert len(stree.nodes()) == 7
assert get_symbol_tree_nodes_count(stree) == 7
multi_input_pattern = MultiInputPattern()
multi_input_pattern.apply(stree)
assert len(construct_ast.body) == 4
assert len(stree.nodes()) == 5
assert get_symbol_tree_nodes_count(stree) == 5
conv1 = stree.get_node("conv1")
conv2 = stree.get_node("conv2")
add1 = stree.get_node("add1")

View File

@ -100,10 +100,52 @@ def erase_relu_in_conv2(stree: SymbolTree):
break
def inset_subtree(stree: SymbolTree):
for node in stree.nodes():
if node.get_name() == "conv2":
position = stree.before(node)
subtree = SubNet()
new_node = Node.create_call_cell(subtree, targets=[ScopedValue.create_naming_value('x')], name='conv',
args=[ScopedValue.create_naming_value('x')], kwargs={})
stree.insert(position, new_node)
break
def inset_subtree2(stree: SymbolTree):
for node in stree.nodes():
if node.get_name() == "conv2":
position = stree.before(node)
subtree = SubNet()
new_node = Node.create_call_cell(subtree, targets=[ScopedValue.create_naming_value('x')], name='conv11',
args=[ScopedValue.create_naming_value('x')], kwargs={})
stree.insert(position, new_node)
break
def add_relu_in_conv11(stree: SymbolTree):
for node in stree.nodes():
if node.get_node_type() != NodeType.Tree:
continue
if node.get_name() == "conv11":
_stree: SymbolTree = TreeNodeHelper.get_sub_tree(node)
for inner_node in _stree.nodes():
if inner_node.get_node_type() != NodeType.Output:
continue
position = _stree.before(inner_node)
new_relu = nn.ReLU()
new_relu_node = Node.create_call_cell(new_relu, targets=['x'], name='relu1',
args=[ScopedValue.create_naming_value('x')])
_stree.insert(position, new_relu_node)
_stree.set_output(0, new_relu_node.get_targets()[0].value)
break
break
def transform(stree: SymbolTree):
add_relu_in_conv1(stree)
replace_bn_in_conv2(stree)
erase_relu_in_conv2(stree)
inset_subtree(stree)
def test_subtree_net():
@ -115,7 +157,20 @@ def test_subtree_net():
net = MainNet()
stree = SymbolTree.create(net)
print(stree.get_code())
transform(stree)
for node in stree.nodes():
print("after transform node name: ", node.get_name(), "; node type: ", node.get_node_type())
if node.get_node_type() != NodeType.Tree:
continue
if node.get_name() == "conv":
modify_stree: SymbolTree = TreeNodeHelper.get_sub_tree(node)
for inner_node in modify_stree.nodes():
print("inserted subtree node: ", inner_node.get_name())
inset_subtree2(stree)
add_relu_in_conv11(stree)
print(stree.get_code())
print(stree.get_handler().get_global_vars().keys())
net_opt = stree.get_network()

View File

@ -21,6 +21,7 @@ from mindspore.rewrite import ScopedValue, ValueType, NodeType
from mindspore.rewrite import Node as NodeApi
from mindspore.rewrite.symbol_tree import SymbolTree
from mindspore.rewrite.node import Node
from .utils import get_symbol_tree_nodes_count
class Network(Cell):
@ -107,7 +108,7 @@ def test_insert_node():
consumers = getattr(getattr(stree, "_topo_mgr"), "_target_consumer")
providers_len = len(providers)
consumers_len = len(consumers)
assert len(stree.nodes()) == 7
assert get_symbol_tree_nodes_count(stree) == 7
assert len(construct_ast.body) == 6
assert len(relu1.get_targets()) == 1
assert len(relu2.get_normalized_args().values()) == 1
@ -120,7 +121,7 @@ def test_insert_node():
position = stree.before(relu2)
node = stree.insert_node(position, node)
# check nodes size
assert len(stree.nodes()) == 8
assert get_symbol_tree_nodes_count(stree) == 8
# check args
assert len(relu2.get_normalized_args().values()) == 1
assert relu1.get_targets()[0] == list(relu2.get_normalized_args().values())[0]
@ -158,7 +159,7 @@ def test_set_node_arg():
Expectation: Success.
"""
stree, bn, relu1, relu2 = create_symbol_tree()
assert len(stree.nodes()) == 7
assert get_symbol_tree_nodes_count(stree) == 7
assert len(bn.get_targets()) == 1
bn_output = bn.get_targets()[0]
# check bn topological order
@ -210,7 +211,7 @@ def test_set_node_arg_by_node():
Expectation: Success.
"""
stree, bn, relu1, relu2 = create_symbol_tree()
assert len(stree.nodes()) == 7
assert get_symbol_tree_nodes_count(stree) == 7
assert len(bn.get_targets()) == 1
bn_output = bn.get_targets()[0]
# check bn topological order
@ -265,13 +266,13 @@ def test_erase_succeed():
construct_ast: ast.FunctionDef = getattr(stree, "_root_ast")
providers = getattr(getattr(stree, "_topo_mgr"), "_target_provider")
providers_len = len(providers)
assert len(stree.nodes()) == 7
assert get_symbol_tree_nodes_count(stree) == 7
assert len(construct_ast.body) == 6
stree.set_node_arg_by_node(relu2, 0, bn)
stree.erase_node(relu1)
assert len(stree.nodes()) == 6
assert get_symbol_tree_nodes_count(stree) == 6
assert len(providers) == providers_len - 1
assert len(construct_ast.body) == 5
@ -300,13 +301,13 @@ def test_replace_one_to_one():
stree, bn, relu1, relu2 = create_symbol_tree()
construct_ast: ast.FunctionDef = getattr(stree, "_root_ast")
assert len(construct_ast.body) == 6
assert len(stree.nodes()) == 7
assert get_symbol_tree_nodes_count(stree) == 7
new_conv = Conv2d(16, 16, 5)
new_conv_node = NodeApi.create_call_cell(new_conv, [ScopedValue.create_naming_value("new_conv")],
bn.get_targets()).get_handler()
new_conv_node = stree.replace(relu1, [new_conv_node])
assert len(stree.nodes()) == 7
assert get_symbol_tree_nodes_count(stree) == 7
# check ast
assert len(construct_ast.body) == 6
node_ast: ast.Assign = construct_ast.body[2]
@ -341,7 +342,7 @@ def test_replace_one_to_multi():
stree, bn, relu1, relu2 = create_symbol_tree()
construct_ast: ast.FunctionDef = getattr(stree, "_root_ast")
assert len(construct_ast.body) == 6
assert len(stree.nodes()) == 7
assert get_symbol_tree_nodes_count(stree) == 7
new_conv_node = NodeApi.create_call_cell(Conv2d(16, 16, 5), [ScopedValue.create_naming_value("new_conv")],
bn.get_targets()).get_handler()
@ -350,7 +351,7 @@ def test_replace_one_to_multi():
new_relu_node = stree.replace(relu1, [new_relu_node, new_conv_node])
new_conv_node = new_relu_node.get_inputs()[0]
assert len(stree.nodes()) == 8
assert get_symbol_tree_nodes_count(stree) == 8
# check ast
assert len(construct_ast.body) == 7
new_conv_ast: ast.Assign = construct_ast.body[2]

View File

@ -0,0 +1,31 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.rewrite.api.symbol_tree import SymbolTree
def get_symbol_tree_nodes_count(stree: SymbolTree):
count = 0
for _ in stree.nodes():
count += 1
return count
def get_node_by_index(stree: SymbolTree, index):
for i, node in enumerate(stree.nodes()):
if i == index:
return node
return None