forked from mindspore-Ecosystem/mindspore
!34805 support const iterator For
Merge pull request !34805 from 于振华/rewrite_for_unfold_0518
This commit is contained in:
commit
d70fe9e7da
|
@ -43,6 +43,7 @@
|
||||||
"mindspore/mindspore/python/mindspore/rewrite/symbol_tree.py" "protected-access"
|
"mindspore/mindspore/python/mindspore/rewrite/symbol_tree.py" "protected-access"
|
||||||
"mindspore/mindspore/python/mindspore/rewrite/parser_register.py" "protected-access"
|
"mindspore/mindspore/python/mindspore/rewrite/parser_register.py" "protected-access"
|
||||||
"mindspore/mindspore/python/mindspore/rewrite/api/pattern_engine.py" "protected-access"
|
"mindspore/mindspore/python/mindspore/rewrite/api/pattern_engine.py" "protected-access"
|
||||||
|
"mindspore/mindspore/python/mindspore/rewrite/parsers/for_parser.py" "eval-used"
|
||||||
"mindspore/mindspore/python/mindspore/rewrite/symbol_tree.py" "inconsistent-return-statements"
|
"mindspore/mindspore/python/mindspore/rewrite/symbol_tree.py" "inconsistent-return-statements"
|
||||||
"mindspore/mindspore/python/mindspore/rewrite/parsers/if_parser.py" "eval-used"
|
"mindspore/mindspore/python/mindspore/rewrite/parsers/if_parser.py" "eval-used"
|
||||||
"mindspore/model_zoo/official/cv" "missing-docstring"
|
"mindspore/model_zoo/official/cv" "missing-docstring"
|
||||||
|
@ -120,6 +121,7 @@
|
||||||
"mindspore/tests/ut/python/rewrite/test_flatten_recursive_stmt.py" "consider-using-ternary"
|
"mindspore/tests/ut/python/rewrite/test_flatten_recursive_stmt.py" "consider-using-ternary"
|
||||||
"mindspore/tests/ut/python/rewrite/test_node.py" "syntax-error"
|
"mindspore/tests/ut/python/rewrite/test_node.py" "syntax-error"
|
||||||
"mindspore/tests/ut/python/rewrite/test_node.py" "protected-access"
|
"mindspore/tests/ut/python/rewrite/test_node.py" "protected-access"
|
||||||
|
"mindspore/tests/ut/python/rewrite/test_for.py" "protected-access"
|
||||||
"mindspore/tests/ut/python/rewrite/test_symbol_tree.py" "len-as-condition"
|
"mindspore/tests/ut/python/rewrite/test_symbol_tree.py" "len-as-condition"
|
||||||
"mindspore/tests/ut/python/rewrite/test_lenet.py" "protected-access"
|
"mindspore/tests/ut/python/rewrite/test_lenet.py" "protected-access"
|
||||||
"mindspore/tests/ut/python/rewrite/test_if.py" "protected-access"
|
"mindspore/tests/ut/python/rewrite/test_if.py" "protected-access"
|
||||||
|
|
|
@ -23,6 +23,7 @@ from .parsers.arguments_parser import g_arguments_parser
|
||||||
from .parsers.assign_parser import g_assign_parser
|
from .parsers.assign_parser import g_assign_parser
|
||||||
from .parsers.if_parser import g_if_parser
|
from .parsers.if_parser import g_if_parser
|
||||||
from .parsers.return_parser import g_return_parser
|
from .parsers.return_parser import g_return_parser
|
||||||
|
from .parsers.for_parser import g_for_parser
|
||||||
from .api.scoped_value import ScopedValue, ValueType
|
from .api.scoped_value import ScopedValue, ValueType
|
||||||
from .api.symbol_tree import SymbolTree
|
from .api.symbol_tree import SymbolTree
|
||||||
from .api.node import Node
|
from .api.node import Node
|
||||||
|
|
|
@ -50,6 +50,12 @@ class AstModifier(ast.NodeTransformer):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def erase_func_from_class_by_name(ast_class: ast.ClassDef, func_name: str):
|
||||||
|
for body in ast_class.body:
|
||||||
|
if isinstance(body, ast.FunctionDef) and body.name == func_name:
|
||||||
|
ast_class.body.remove(body)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def insert_sub_ast(ast_father: ast.AST, ast_son: ast.AST, index_ast: Optional[ast.AST] = None,
|
def insert_sub_ast(ast_father: ast.AST, ast_son: ast.AST, index_ast: Optional[ast.AST] = None,
|
||||||
insert_before=True) -> ast.AST:
|
insert_before=True) -> ast.AST:
|
||||||
|
|
|
@ -0,0 +1,93 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
""" Parse ast.For node """
|
||||||
|
import ast
|
||||||
|
import astunparse
|
||||||
|
|
||||||
|
from mindspore.rewrite.api.scoped_value import ScopedValue, ValueType
|
||||||
|
from mindspore.rewrite.ast_helpers.ast_modifier import AstModifier
|
||||||
|
from mindspore import log as logger
|
||||||
|
from ..symbol_tree import SymbolTree
|
||||||
|
from ..parser import Parser
|
||||||
|
from ..parser_register import reg_parser
|
||||||
|
from ..common.event import Event
|
||||||
|
|
||||||
|
EVAL_WHITE_LIST = ("self.", "range(", "zip(", "enumerate(", "reversed(")
|
||||||
|
|
||||||
|
|
||||||
|
class ForParser(Parser):
|
||||||
|
""" Class that implements parsing ast.For nodes """
|
||||||
|
@staticmethod
|
||||||
|
def modify_init_ast(stree, i, obj, iter_var_name):
|
||||||
|
"""Modify the ast node in init function."""
|
||||||
|
target = f"{iter_var_name.strip()}_{str(i)}"
|
||||||
|
stree.add_global_vars(target, obj)
|
||||||
|
stree.get_origin_network().insert_child_to_cell(target, obj)
|
||||||
|
AstModifier.insert_assign_to_function(stree.get_init_func_ast(),
|
||||||
|
targets=[ScopedValue(ValueType.NamingValue, "self", target)],
|
||||||
|
expr=ScopedValue(ValueType.NamingValue, "global_vars", "get"),
|
||||||
|
args=[ScopedValue(ValueType.StringValue, "", target)])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def modify_construct_ast(stree, ast_node, old_name, new_name):
|
||||||
|
"""Modify the ast node in construct function."""
|
||||||
|
node_str: str = astunparse.unparse(ast_node)
|
||||||
|
node_str = node_str.replace(old_name, new_name)
|
||||||
|
module_node = ast.parse(node_str)
|
||||||
|
new_node = module_node.body[0]
|
||||||
|
return new_node
|
||||||
|
|
||||||
|
def target(self):
|
||||||
|
return ast.For
|
||||||
|
|
||||||
|
def process(self, stree: SymbolTree, node: ast.For):
|
||||||
|
""" Process ast.For node """
|
||||||
|
if isinstance(node.target, ast.Name):
|
||||||
|
targets = node.target.id
|
||||||
|
iter_code = astunparse.unparse(node.iter)
|
||||||
|
if not iter_code.startswith(EVAL_WHITE_LIST):
|
||||||
|
logger.warning(f"Illegal iteration condition for For node, it must start with{EVAL_WHITE_LIST}")
|
||||||
|
return
|
||||||
|
if iter_code.startswith("self"):
|
||||||
|
iter_code = iter_code.replace("self", "stree.get_origin_network()")
|
||||||
|
try:
|
||||||
|
iter_obj = eval(iter_code)
|
||||||
|
except Exception as e:
|
||||||
|
error_info = f"When eval '{iter_code}' by using JIT Fallback feature, an error occurred: {str(e)}"
|
||||||
|
logger.error(error_info)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
iter_var_name = iter_code.split(".")[-1]
|
||||||
|
index = stree.get_ast_root().body.index(node) + 1
|
||||||
|
if isinstance(iter_obj, list):
|
||||||
|
for i, obj in enumerate(iter_obj):
|
||||||
|
ForParser.modify_init_ast(stree, i, obj, iter_var_name)
|
||||||
|
for body in node.body:
|
||||||
|
new_func_name = f"self.{iter_var_name.strip()}_{str(i)}".strip()
|
||||||
|
new_node = ForParser.modify_construct_ast(stree, body, targets, new_func_name)
|
||||||
|
stree.get_ast_root().body.insert(index, new_node)
|
||||||
|
index += 1
|
||||||
|
if stree.get_ori_cls_name() == "SequentialCell":
|
||||||
|
stree.on_change(Event.CodeChangeEvent)
|
||||||
|
elif isinstance(iter_obj, range):
|
||||||
|
raise NotImplementedError("range not support")
|
||||||
|
elif isinstance(iter_obj, zip):
|
||||||
|
raise NotImplementedError("zip not support")
|
||||||
|
elif isinstance(iter_obj, enumerate):
|
||||||
|
raise NotImplementedError("enumerate not support")
|
||||||
|
else:
|
||||||
|
raise ValueError("not supported type: ", iter_obj)
|
||||||
|
|
||||||
|
g_for_parser = reg_parser(ForParser())
|
|
@ -44,6 +44,9 @@ class FunctionDefParser(Parser):
|
||||||
else:
|
else:
|
||||||
parser.process(stree, body)
|
parser.process(stree, body)
|
||||||
|
|
||||||
|
for body in node.body:
|
||||||
|
if isinstance(body, ast.For):
|
||||||
|
node.body.remove(body)
|
||||||
if hasattr(node, "decorator_list"):
|
if hasattr(node, "decorator_list"):
|
||||||
stree.try_append_python_node(node, node.decorator_list)
|
stree.try_append_python_node(node, node.decorator_list)
|
||||||
if hasattr(node, "returns"):
|
if hasattr(node, "returns"):
|
||||||
|
|
|
@ -766,8 +766,6 @@ class SymbolTree(Observer, Observable):
|
||||||
value.isolate()
|
value.isolate()
|
||||||
break
|
break
|
||||||
self._topo_mgr.on_erase_node(node)
|
self._topo_mgr.on_erase_node(node)
|
||||||
if self._node_visitor:
|
|
||||||
self._node_visitor.remove_node(node)
|
|
||||||
return node
|
return node
|
||||||
|
|
||||||
def replace(self, old_node: Node, new_nodes: [Node]) -> Node:
|
def replace(self, old_node: Node, new_nodes: [Node]) -> Node:
|
||||||
|
@ -932,6 +930,9 @@ class SymbolTree(Observer, Observable):
|
||||||
body = self._module_ast.body[i]
|
body = self._module_ast.body[i]
|
||||||
if not isinstance(body, (ast.Import, ast.ImportFrom)):
|
if not isinstance(body, (ast.Import, ast.ImportFrom)):
|
||||||
continue
|
continue
|
||||||
|
if isinstance(body, ast.ImportFrom) and body.module == "cell":
|
||||||
|
self._module_ast.body.remove(body)
|
||||||
|
continue
|
||||||
for alias in body.names:
|
for alias in body.names:
|
||||||
name = alias.asname if alias.asname else alias.name
|
name = alias.asname if alias.asname else alias.name
|
||||||
if not str_checker.check(name):
|
if not str_checker.check(name):
|
||||||
|
|
|
@ -54,6 +54,8 @@ class SymbolTreeBuilder:
|
||||||
3. Use merged ast.Module as module of main-network and sub-network.
|
3. Use merged ast.Module as module of main-network and sub-network.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if sub_stree.get_ori_cls_name() == "SequentialCell":
|
||||||
|
SymbolTreeBuilder._erase_unused_func_of_sequentialcell(sub_stree.get_class_ast())
|
||||||
father_mod = main_tree.get_module_ast()
|
father_mod = main_tree.get_module_ast()
|
||||||
sub_mod = sub_stree.get_module_ast()
|
sub_mod = sub_stree.get_module_ast()
|
||||||
SymbolTreeBuilder._merge_import_of_module(father_mod, sub_mod)
|
SymbolTreeBuilder._merge_import_of_module(father_mod, sub_mod)
|
||||||
|
@ -122,6 +124,12 @@ class SymbolTreeBuilder:
|
||||||
for clazz in classes_in_sub:
|
for clazz in classes_in_sub:
|
||||||
AstModifier.insert_class_into_module(main_mod, clazz, first_class, True)
|
AstModifier.insert_class_into_module(main_mod, clazz, first_class, True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _erase_unused_func_of_sequentialcell(ast_class: ast.ClassDef):
|
||||||
|
func_names = ("__getitem__", "__setitem__", "__delitem__", "__len__", "append")
|
||||||
|
for name in func_names:
|
||||||
|
AstModifier.erase_func_from_class_by_name(ast_class, name)
|
||||||
|
|
||||||
def _merge_module_of_subtrees(self):
|
def _merge_module_of_subtrees(self):
|
||||||
"""
|
"""
|
||||||
Merge ast.Module of all sub-networks into ast.Module of main-network.
|
Merge ast.Module of all sub-networks into ast.Module of main-network.
|
||||||
|
|
|
@ -0,0 +1,322 @@
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
from mindspore import nn
|
||||||
|
from mindspore.rewrite import SymbolTree, PatternEngine, Replacement, PatternNode, Node, ScopedValue
|
||||||
|
from mindspore.rewrite.api.tree_node_helper import TreeNodeHelper
|
||||||
|
from mindspore.rewrite.api.node_type import NodeType
|
||||||
|
|
||||||
|
|
||||||
|
def make_layer(block, layer_num, in_channel, out_channel, stride, use_se=False, se_block=False):
|
||||||
|
"""
|
||||||
|
Make stage network of ResNet.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
block (Cell): Resnet block.
|
||||||
|
layer_num (int): Layer number.
|
||||||
|
in_channel (int): Input channel.
|
||||||
|
out_channel (int): Output channel.
|
||||||
|
stride (int): Stride size for the first convolutional layer.
|
||||||
|
se_block(bool): Use se block in SE-ResNet50 net. Default: False.
|
||||||
|
Returns:
|
||||||
|
SequentialCell, the output layer.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> _make_layer(ResidualBlock, 3, 128, 256, 2)
|
||||||
|
"""
|
||||||
|
layers = []
|
||||||
|
|
||||||
|
resnet_block = block(in_channel, out_channel, stride=stride, use_se=use_se)
|
||||||
|
layers.append(resnet_block)
|
||||||
|
if se_block:
|
||||||
|
for _ in range(1, layer_num - 1):
|
||||||
|
resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
|
||||||
|
layers.append(resnet_block)
|
||||||
|
resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se, se_block=se_block)
|
||||||
|
layers.append(resnet_block)
|
||||||
|
else:
|
||||||
|
for _ in range(1, layer_num):
|
||||||
|
resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
|
||||||
|
layers.append(resnet_block)
|
||||||
|
return nn.SequentialCell(layers)
|
||||||
|
|
||||||
|
|
||||||
|
class ConvBnReplace(Replacement):
|
||||||
|
def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
|
||||||
|
bn_node: Node = matched.get(pattern.name())
|
||||||
|
bn: nn.BatchNorm2d = bn_node.get_instance()
|
||||||
|
conv_p = pattern.get_inputs()[0]
|
||||||
|
conv_node: Node = matched.get(conv_p.name())
|
||||||
|
conv: nn.Conv2d = conv_node.get_instance()
|
||||||
|
newconv = nn.Conv2dBnAct(conv.in_channels,
|
||||||
|
conv.out_channels,
|
||||||
|
conv.kernel_size,
|
||||||
|
conv.stride,
|
||||||
|
conv.pad_mode,
|
||||||
|
conv.padding,
|
||||||
|
conv.dilation,
|
||||||
|
conv.group,
|
||||||
|
conv.has_bias,
|
||||||
|
conv.weight_init,
|
||||||
|
conv.bias_init,
|
||||||
|
True,
|
||||||
|
bn.momentum,
|
||||||
|
bn.eps)
|
||||||
|
newconv_node = Node.create_call_cell(newconv, bn_node.get_targets(), conv_node.get_args(),
|
||||||
|
conv_node.get_kwargs(), "Conv2dBnAct")
|
||||||
|
return [newconv_node]
|
||||||
|
|
||||||
|
|
||||||
|
class ConvBnPattern(PatternEngine):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__([nn.Conv2d, nn.BatchNorm2d], ConvBnReplace())
|
||||||
|
|
||||||
|
|
||||||
|
class CellBlock(nn.Cell):
|
||||||
|
"""
|
||||||
|
ResNet V1 residual block definition.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channel (int): Input channel.
|
||||||
|
out_channel (int): Output channel.
|
||||||
|
stride (int): Stride size for the first convolutional layer. Default: 1.
|
||||||
|
use_se (bool): Enable SE-ResNet50 net. Default: False.
|
||||||
|
se_block(bool): Use se block in SE-ResNet50 net. Default: False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor, output tensor.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> ResidualBlock(3, 256, stride=2)
|
||||||
|
"""
|
||||||
|
expansion = 4
|
||||||
|
|
||||||
|
def __init__(self, in_channel, out_channel, stride=1,):
|
||||||
|
super(CellBlock, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(3, 6, 1, stride=1)
|
||||||
|
self.bn1 = nn.BatchNorm2d(6, eps=1e-4, momentum=0.9,
|
||||||
|
gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
self.down_sample_layer = nn.SequentialCell([nn.Conv2d(in_channel, out_channel, 1)])
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.bn1(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
x = self.down_sample_layer(x)
|
||||||
|
out = out + x
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class ForNetWithSubTree(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(ForNetWithSubTree, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(3, 6, 1)
|
||||||
|
self.conv2 = nn.Conv2d(6, 16, 1)
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
self.relu1 = nn.ReLU()
|
||||||
|
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||||
|
self.max_pool2d1 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||||
|
layers1 = [self.conv1, self.conv2, self.max_pool2d, self.relu]
|
||||||
|
self.layer1 = nn.SequentialCell(layers1)
|
||||||
|
|
||||||
|
resnet_block1 = CellBlock(3, 6)
|
||||||
|
resnet_block2 = CellBlock(6, 16)
|
||||||
|
resnet_block3 = CellBlock(16, 32)
|
||||||
|
layers = [resnet_block1, resnet_block2, resnet_block3]
|
||||||
|
self.layer2 = nn.SequentialCell(layers)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.layer1(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
x = self.layer2(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def test_erase_subtree_node():
|
||||||
|
"""
|
||||||
|
Feature: for parser and erase api.
|
||||||
|
Description: erase a node in subtree of `SymbolTree`.
|
||||||
|
Expectation: Success.
|
||||||
|
"""
|
||||||
|
net = ForNetWithSubTree()
|
||||||
|
stree = SymbolTree.create(net)
|
||||||
|
|
||||||
|
for node in stree.nodes():
|
||||||
|
if node.get_name() == "layer1":
|
||||||
|
subtree = TreeNodeHelper.get_sub_tree(node)
|
||||||
|
orig_node_num = len(subtree.get_handler()._nodes)
|
||||||
|
for n in subtree.nodes():
|
||||||
|
if n.get_instance_type() == nn.MaxPool2d:
|
||||||
|
input_node = n.get_inputs()[0]
|
||||||
|
output_nodes = n.get_users()
|
||||||
|
for out_node in output_nodes:
|
||||||
|
out_node.set_arg_by_node(0, input_node)
|
||||||
|
subtree.erase_node(n)
|
||||||
|
break
|
||||||
|
assert len(subtree.get_handler()._nodes) == orig_node_num - 1
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
def test_erase_subtree_node_01():
|
||||||
|
"""
|
||||||
|
Feature: for parser and erase api.
|
||||||
|
Description: erase a node in subtree of `SymbolTree`.
|
||||||
|
Expectation: Success.
|
||||||
|
"""
|
||||||
|
net = ForNetWithSubTree()
|
||||||
|
stree = SymbolTree.create(net)
|
||||||
|
|
||||||
|
for node in stree.nodes():
|
||||||
|
if node.get_name() == "layer2":
|
||||||
|
subtree = TreeNodeHelper.get_sub_tree(node)
|
||||||
|
orig_node_num = len(subtree.get_handler()._nodes)
|
||||||
|
for n in subtree.nodes():
|
||||||
|
if n.get_name() == "cell_list_1":
|
||||||
|
input_node = n.get_inputs()[0]
|
||||||
|
output_nodes = n.get_users()
|
||||||
|
for _nn in output_nodes:
|
||||||
|
_nn.set_arg_by_node(0, input_node)
|
||||||
|
subtree.erase_node(n)
|
||||||
|
assert len(subtree.get_handler()._nodes) == orig_node_num - 1
|
||||||
|
break
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
def test_erase_subtree_node_02():
|
||||||
|
"""
|
||||||
|
Feature: for parser and erase api.
|
||||||
|
Description: for parser and erase node in subtree of `SymbolTree`.
|
||||||
|
Expectation: Success.
|
||||||
|
"""
|
||||||
|
def _remove_bn(subtree):
|
||||||
|
for node in subtree.nodes():
|
||||||
|
if node.get_name() == "bn1":
|
||||||
|
input_node = node.get_inputs()[0]
|
||||||
|
output_nodes = node.get_users()
|
||||||
|
for n in output_nodes:
|
||||||
|
n.set_arg_by_node(0, input_node)
|
||||||
|
subtree.erase_node(node)
|
||||||
|
break
|
||||||
|
|
||||||
|
net = ForNetWithSubTree()
|
||||||
|
stree = SymbolTree.create(net)
|
||||||
|
for node in stree.nodes():
|
||||||
|
if node.get_name() == "layer2":
|
||||||
|
subtree = TreeNodeHelper.get_sub_tree(node)
|
||||||
|
for n in subtree.nodes():
|
||||||
|
if n.get_name() == "cell_list_1":
|
||||||
|
subtree1 = TreeNodeHelper.get_sub_tree(n)
|
||||||
|
_remove_bn(subtree1)
|
||||||
|
assert subtree1.get_node("bn1") is None
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
def test_insert_subtree_node():
|
||||||
|
"""
|
||||||
|
Feature: for parser and insert api.
|
||||||
|
Description: Insert node into subtree in `Symboltree`.
|
||||||
|
Expectation: Success.
|
||||||
|
"""
|
||||||
|
def _insert_node(subtree):
|
||||||
|
for node in subtree.nodes():
|
||||||
|
if node.get_name() == "bn1":
|
||||||
|
position = subtree.before(node)
|
||||||
|
new_conv = nn.Conv2d(16, 16, 3)
|
||||||
|
new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv',
|
||||||
|
args=[ScopedValue.create_naming_value('self_max_po')])
|
||||||
|
subtree.insert(position, new_conv_node)
|
||||||
|
|
||||||
|
net = ForNetWithSubTree()
|
||||||
|
stree = SymbolTree.create(net)
|
||||||
|
for node in stree.nodes():
|
||||||
|
if node.get_name() == "layer2" and node.get_node_type() == NodeType.Tree:
|
||||||
|
subtree = TreeNodeHelper.get_sub_tree(node)
|
||||||
|
for n in subtree.nodes():
|
||||||
|
if n.get_name() == "cell_list_1":
|
||||||
|
subtree1 = TreeNodeHelper.get_sub_tree(n)
|
||||||
|
orig_node_num = len(subtree1.get_handler()._nodes)
|
||||||
|
_insert_node(subtree1)
|
||||||
|
assert len(subtree1.get_handler()._nodes) == orig_node_num + 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_resnet_replace_121():
|
||||||
|
"""
|
||||||
|
Feature: for parser and replace api.
|
||||||
|
Description: Replace one node by one nodes in subtree of `SymbolTree`..
|
||||||
|
Expectation: Success.
|
||||||
|
"""
|
||||||
|
net = ForNetWithSubTree()
|
||||||
|
stree: SymbolTree = SymbolTree.create(net)
|
||||||
|
original_nodes_size = len(stree.get_handler()._nodes)
|
||||||
|
for node in stree.nodes():
|
||||||
|
if node.get_name() == "layer1" and node.get_node_type() == NodeType.Tree:
|
||||||
|
subtree = TreeNodeHelper.get_sub_tree(node)
|
||||||
|
for n in subtree.nodes():
|
||||||
|
if n.get_instance_type() == nn.Conv2d:
|
||||||
|
conv: nn.Conv2d = n.get_instance()
|
||||||
|
new_conv = Node.create_call_cell(nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size),
|
||||||
|
targets=n.get_targets(), args=n.get_args(),
|
||||||
|
kwargs=node.get_kwargs(), name="new_conv")
|
||||||
|
subtree.replace(n, [new_conv])
|
||||||
|
break
|
||||||
|
assert len(stree.get_handler()._nodes) == original_nodes_size
|
||||||
|
|
||||||
|
|
||||||
|
def test_resnet_replace_12m():
|
||||||
|
"""
|
||||||
|
Feature: for parser and replace api.
|
||||||
|
Description: Replace one node by multi-nodes in subtree of `SymbolTree`.
|
||||||
|
Expectation: Success.
|
||||||
|
"""
|
||||||
|
net = ForNetWithSubTree()
|
||||||
|
stree: SymbolTree = SymbolTree.create(net)
|
||||||
|
|
||||||
|
for node in stree.nodes():
|
||||||
|
if node.get_name() == "layer1" and node.get_node_type() == NodeType.Tree:
|
||||||
|
subtree = TreeNodeHelper.get_sub_tree(node)
|
||||||
|
original_nodes_size = len(subtree.get_handler()._nodes)
|
||||||
|
for n in subtree.nodes():
|
||||||
|
if n.get_instance_type() == nn.Conv2d:
|
||||||
|
conv: nn.Conv2d = n.get_instance()
|
||||||
|
new_conv = Node.create_call_cell(nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size),
|
||||||
|
targets=["x"], args=n.get_args(),
|
||||||
|
kwargs=node.get_kwargs(), name="new_conv")
|
||||||
|
new_bn = Node.create_call_cell(nn.BatchNorm2d(conv.out_channels),
|
||||||
|
targets=n.get_targets(), args=[ScopedValue.create_naming_value("x")],
|
||||||
|
kwargs={}, name="new_bn")
|
||||||
|
subtree.replace(n, [new_conv, new_bn])
|
||||||
|
break
|
||||||
|
assert len(subtree.get_handler()._nodes) == original_nodes_size + 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_node_fusion_in_subtree():
|
||||||
|
"""
|
||||||
|
Feature: for parser and PatternEngine.
|
||||||
|
Description: Apply PatternEngine on nodes in `SymbolTree`..
|
||||||
|
Expectation: Success.
|
||||||
|
"""
|
||||||
|
net = ForNetWithSubTree()
|
||||||
|
stree: SymbolTree = SymbolTree.create(net)
|
||||||
|
original_nodes_size = len(stree.get_handler()._nodes)
|
||||||
|
for node in stree.nodes():
|
||||||
|
if node.get_name() == "layer1" and node.get_node_type() == NodeType.Tree:
|
||||||
|
subtree = TreeNodeHelper.get_sub_tree(node)
|
||||||
|
original_nodes_size = len(subtree.get_handler()._nodes)
|
||||||
|
for n in subtree.nodes():
|
||||||
|
node_: Node = n
|
||||||
|
if node_.get_instance_type() == nn.Conv2d:
|
||||||
|
old_bn = node_.get_users()[0]
|
||||||
|
pos = subtree.after(node_)
|
||||||
|
conv: nn.Conv2d = node_.get_instance()
|
||||||
|
new_bn = Node.create_call_cell(nn.BatchNorm2d(conv.out_channels), targets=["x"],
|
||||||
|
args=[node_.get_targets()[0]], kwargs={}, name="new_bn")
|
||||||
|
subtree.insert(pos, new_bn)
|
||||||
|
old_bn.set_arg_by_node(0, new_bn)
|
||||||
|
break
|
||||||
|
assert len(subtree.get_handler()._nodes) == original_nodes_size + 1
|
||||||
|
ConvBnPattern().apply(subtree)
|
||||||
|
assert len(subtree.get_handler()._nodes) == original_nodes_size
|
||||||
|
assert not subtree.get_node("conv1")
|
||||||
|
assert not subtree.get_node("new_bn")
|
Loading…
Reference in New Issue