forked from mindspore-Ecosystem/mindspore
!34805 support const iterator For
Merge pull request !34805 from 于振华/rewrite_for_unfold_0518
This commit is contained in:
commit
d70fe9e7da
|
@ -43,6 +43,7 @@
|
|||
"mindspore/mindspore/python/mindspore/rewrite/symbol_tree.py" "protected-access"
|
||||
"mindspore/mindspore/python/mindspore/rewrite/parser_register.py" "protected-access"
|
||||
"mindspore/mindspore/python/mindspore/rewrite/api/pattern_engine.py" "protected-access"
|
||||
"mindspore/mindspore/python/mindspore/rewrite/parsers/for_parser.py" "eval-used"
|
||||
"mindspore/mindspore/python/mindspore/rewrite/symbol_tree.py" "inconsistent-return-statements"
|
||||
"mindspore/mindspore/python/mindspore/rewrite/parsers/if_parser.py" "eval-used"
|
||||
"mindspore/model_zoo/official/cv" "missing-docstring"
|
||||
|
@ -120,6 +121,7 @@
|
|||
"mindspore/tests/ut/python/rewrite/test_flatten_recursive_stmt.py" "consider-using-ternary"
|
||||
"mindspore/tests/ut/python/rewrite/test_node.py" "syntax-error"
|
||||
"mindspore/tests/ut/python/rewrite/test_node.py" "protected-access"
|
||||
"mindspore/tests/ut/python/rewrite/test_for.py" "protected-access"
|
||||
"mindspore/tests/ut/python/rewrite/test_symbol_tree.py" "len-as-condition"
|
||||
"mindspore/tests/ut/python/rewrite/test_lenet.py" "protected-access"
|
||||
"mindspore/tests/ut/python/rewrite/test_if.py" "protected-access"
|
||||
|
|
|
@ -23,6 +23,7 @@ from .parsers.arguments_parser import g_arguments_parser
|
|||
from .parsers.assign_parser import g_assign_parser
|
||||
from .parsers.if_parser import g_if_parser
|
||||
from .parsers.return_parser import g_return_parser
|
||||
from .parsers.for_parser import g_for_parser
|
||||
from .api.scoped_value import ScopedValue, ValueType
|
||||
from .api.symbol_tree import SymbolTree
|
||||
from .api.node import Node
|
||||
|
|
|
@ -50,6 +50,12 @@ class AstModifier(ast.NodeTransformer):
|
|||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def erase_func_from_class_by_name(ast_class: ast.ClassDef, func_name: str):
|
||||
for body in ast_class.body:
|
||||
if isinstance(body, ast.FunctionDef) and body.name == func_name:
|
||||
ast_class.body.remove(body)
|
||||
|
||||
@staticmethod
|
||||
def insert_sub_ast(ast_father: ast.AST, ast_son: ast.AST, index_ast: Optional[ast.AST] = None,
|
||||
insert_before=True) -> ast.AST:
|
||||
|
|
|
@ -0,0 +1,93 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" Parse ast.For node """
|
||||
import ast
|
||||
import astunparse
|
||||
|
||||
from mindspore.rewrite.api.scoped_value import ScopedValue, ValueType
|
||||
from mindspore.rewrite.ast_helpers.ast_modifier import AstModifier
|
||||
from mindspore import log as logger
|
||||
from ..symbol_tree import SymbolTree
|
||||
from ..parser import Parser
|
||||
from ..parser_register import reg_parser
|
||||
from ..common.event import Event
|
||||
|
||||
EVAL_WHITE_LIST = ("self.", "range(", "zip(", "enumerate(", "reversed(")
|
||||
|
||||
|
||||
class ForParser(Parser):
|
||||
""" Class that implements parsing ast.For nodes """
|
||||
@staticmethod
|
||||
def modify_init_ast(stree, i, obj, iter_var_name):
|
||||
"""Modify the ast node in init function."""
|
||||
target = f"{iter_var_name.strip()}_{str(i)}"
|
||||
stree.add_global_vars(target, obj)
|
||||
stree.get_origin_network().insert_child_to_cell(target, obj)
|
||||
AstModifier.insert_assign_to_function(stree.get_init_func_ast(),
|
||||
targets=[ScopedValue(ValueType.NamingValue, "self", target)],
|
||||
expr=ScopedValue(ValueType.NamingValue, "global_vars", "get"),
|
||||
args=[ScopedValue(ValueType.StringValue, "", target)])
|
||||
|
||||
@staticmethod
|
||||
def modify_construct_ast(stree, ast_node, old_name, new_name):
|
||||
"""Modify the ast node in construct function."""
|
||||
node_str: str = astunparse.unparse(ast_node)
|
||||
node_str = node_str.replace(old_name, new_name)
|
||||
module_node = ast.parse(node_str)
|
||||
new_node = module_node.body[0]
|
||||
return new_node
|
||||
|
||||
def target(self):
|
||||
return ast.For
|
||||
|
||||
def process(self, stree: SymbolTree, node: ast.For):
|
||||
""" Process ast.For node """
|
||||
if isinstance(node.target, ast.Name):
|
||||
targets = node.target.id
|
||||
iter_code = astunparse.unparse(node.iter)
|
||||
if not iter_code.startswith(EVAL_WHITE_LIST):
|
||||
logger.warning(f"Illegal iteration condition for For node, it must start with{EVAL_WHITE_LIST}")
|
||||
return
|
||||
if iter_code.startswith("self"):
|
||||
iter_code = iter_code.replace("self", "stree.get_origin_network()")
|
||||
try:
|
||||
iter_obj = eval(iter_code)
|
||||
except Exception as e:
|
||||
error_info = f"When eval '{iter_code}' by using JIT Fallback feature, an error occurred: {str(e)}"
|
||||
logger.error(error_info)
|
||||
raise e
|
||||
|
||||
iter_var_name = iter_code.split(".")[-1]
|
||||
index = stree.get_ast_root().body.index(node) + 1
|
||||
if isinstance(iter_obj, list):
|
||||
for i, obj in enumerate(iter_obj):
|
||||
ForParser.modify_init_ast(stree, i, obj, iter_var_name)
|
||||
for body in node.body:
|
||||
new_func_name = f"self.{iter_var_name.strip()}_{str(i)}".strip()
|
||||
new_node = ForParser.modify_construct_ast(stree, body, targets, new_func_name)
|
||||
stree.get_ast_root().body.insert(index, new_node)
|
||||
index += 1
|
||||
if stree.get_ori_cls_name() == "SequentialCell":
|
||||
stree.on_change(Event.CodeChangeEvent)
|
||||
elif isinstance(iter_obj, range):
|
||||
raise NotImplementedError("range not support")
|
||||
elif isinstance(iter_obj, zip):
|
||||
raise NotImplementedError("zip not support")
|
||||
elif isinstance(iter_obj, enumerate):
|
||||
raise NotImplementedError("enumerate not support")
|
||||
else:
|
||||
raise ValueError("not supported type: ", iter_obj)
|
||||
|
||||
g_for_parser = reg_parser(ForParser())
|
|
@ -44,6 +44,9 @@ class FunctionDefParser(Parser):
|
|||
else:
|
||||
parser.process(stree, body)
|
||||
|
||||
for body in node.body:
|
||||
if isinstance(body, ast.For):
|
||||
node.body.remove(body)
|
||||
if hasattr(node, "decorator_list"):
|
||||
stree.try_append_python_node(node, node.decorator_list)
|
||||
if hasattr(node, "returns"):
|
||||
|
|
|
@ -766,8 +766,6 @@ class SymbolTree(Observer, Observable):
|
|||
value.isolate()
|
||||
break
|
||||
self._topo_mgr.on_erase_node(node)
|
||||
if self._node_visitor:
|
||||
self._node_visitor.remove_node(node)
|
||||
return node
|
||||
|
||||
def replace(self, old_node: Node, new_nodes: [Node]) -> Node:
|
||||
|
@ -932,6 +930,9 @@ class SymbolTree(Observer, Observable):
|
|||
body = self._module_ast.body[i]
|
||||
if not isinstance(body, (ast.Import, ast.ImportFrom)):
|
||||
continue
|
||||
if isinstance(body, ast.ImportFrom) and body.module == "cell":
|
||||
self._module_ast.body.remove(body)
|
||||
continue
|
||||
for alias in body.names:
|
||||
name = alias.asname if alias.asname else alias.name
|
||||
if not str_checker.check(name):
|
||||
|
|
|
@ -54,6 +54,8 @@ class SymbolTreeBuilder:
|
|||
3. Use merged ast.Module as module of main-network and sub-network.
|
||||
"""
|
||||
|
||||
if sub_stree.get_ori_cls_name() == "SequentialCell":
|
||||
SymbolTreeBuilder._erase_unused_func_of_sequentialcell(sub_stree.get_class_ast())
|
||||
father_mod = main_tree.get_module_ast()
|
||||
sub_mod = sub_stree.get_module_ast()
|
||||
SymbolTreeBuilder._merge_import_of_module(father_mod, sub_mod)
|
||||
|
@ -122,6 +124,12 @@ class SymbolTreeBuilder:
|
|||
for clazz in classes_in_sub:
|
||||
AstModifier.insert_class_into_module(main_mod, clazz, first_class, True)
|
||||
|
||||
@staticmethod
|
||||
def _erase_unused_func_of_sequentialcell(ast_class: ast.ClassDef):
|
||||
func_names = ("__getitem__", "__setitem__", "__delitem__", "__len__", "append")
|
||||
for name in func_names:
|
||||
AstModifier.erase_func_from_class_by_name(ast_class, name)
|
||||
|
||||
def _merge_module_of_subtrees(self):
|
||||
"""
|
||||
Merge ast.Module of all sub-networks into ast.Module of main-network.
|
||||
|
|
|
@ -0,0 +1,322 @@
|
|||
|
||||
from collections import OrderedDict
|
||||
|
||||
from mindspore import nn
|
||||
from mindspore.rewrite import SymbolTree, PatternEngine, Replacement, PatternNode, Node, ScopedValue
|
||||
from mindspore.rewrite.api.tree_node_helper import TreeNodeHelper
|
||||
from mindspore.rewrite.api.node_type import NodeType
|
||||
|
||||
|
||||
def make_layer(block, layer_num, in_channel, out_channel, stride, use_se=False, se_block=False):
|
||||
"""
|
||||
Make stage network of ResNet.
|
||||
|
||||
Args:
|
||||
block (Cell): Resnet block.
|
||||
layer_num (int): Layer number.
|
||||
in_channel (int): Input channel.
|
||||
out_channel (int): Output channel.
|
||||
stride (int): Stride size for the first convolutional layer.
|
||||
se_block(bool): Use se block in SE-ResNet50 net. Default: False.
|
||||
Returns:
|
||||
SequentialCell, the output layer.
|
||||
|
||||
Examples:
|
||||
>>> _make_layer(ResidualBlock, 3, 128, 256, 2)
|
||||
"""
|
||||
layers = []
|
||||
|
||||
resnet_block = block(in_channel, out_channel, stride=stride, use_se=use_se)
|
||||
layers.append(resnet_block)
|
||||
if se_block:
|
||||
for _ in range(1, layer_num - 1):
|
||||
resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
|
||||
layers.append(resnet_block)
|
||||
resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se, se_block=se_block)
|
||||
layers.append(resnet_block)
|
||||
else:
|
||||
for _ in range(1, layer_num):
|
||||
resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
|
||||
layers.append(resnet_block)
|
||||
return nn.SequentialCell(layers)
|
||||
|
||||
|
||||
class ConvBnReplace(Replacement):
|
||||
def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
|
||||
bn_node: Node = matched.get(pattern.name())
|
||||
bn: nn.BatchNorm2d = bn_node.get_instance()
|
||||
conv_p = pattern.get_inputs()[0]
|
||||
conv_node: Node = matched.get(conv_p.name())
|
||||
conv: nn.Conv2d = conv_node.get_instance()
|
||||
newconv = nn.Conv2dBnAct(conv.in_channels,
|
||||
conv.out_channels,
|
||||
conv.kernel_size,
|
||||
conv.stride,
|
||||
conv.pad_mode,
|
||||
conv.padding,
|
||||
conv.dilation,
|
||||
conv.group,
|
||||
conv.has_bias,
|
||||
conv.weight_init,
|
||||
conv.bias_init,
|
||||
True,
|
||||
bn.momentum,
|
||||
bn.eps)
|
||||
newconv_node = Node.create_call_cell(newconv, bn_node.get_targets(), conv_node.get_args(),
|
||||
conv_node.get_kwargs(), "Conv2dBnAct")
|
||||
return [newconv_node]
|
||||
|
||||
|
||||
class ConvBnPattern(PatternEngine):
|
||||
def __init__(self):
|
||||
super().__init__([nn.Conv2d, nn.BatchNorm2d], ConvBnReplace())
|
||||
|
||||
|
||||
class CellBlock(nn.Cell):
|
||||
"""
|
||||
ResNet V1 residual block definition.
|
||||
|
||||
Args:
|
||||
in_channel (int): Input channel.
|
||||
out_channel (int): Output channel.
|
||||
stride (int): Stride size for the first convolutional layer. Default: 1.
|
||||
use_se (bool): Enable SE-ResNet50 net. Default: False.
|
||||
se_block(bool): Use se block in SE-ResNet50 net. Default: False.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
|
||||
Examples:
|
||||
>>> ResidualBlock(3, 256, stride=2)
|
||||
"""
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, in_channel, out_channel, stride=1,):
|
||||
super(CellBlock, self).__init__()
|
||||
self.conv1 = nn.Conv2d(3, 6, 1, stride=1)
|
||||
self.bn1 = nn.BatchNorm2d(6, eps=1e-4, momentum=0.9,
|
||||
gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)
|
||||
self.relu = nn.ReLU()
|
||||
self.down_sample_layer = nn.SequentialCell([nn.Conv2d(in_channel, out_channel, 1)])
|
||||
|
||||
def construct(self, x):
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
x = self.down_sample_layer(x)
|
||||
out = out + x
|
||||
return out
|
||||
|
||||
|
||||
class ForNetWithSubTree(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ForNetWithSubTree, self).__init__()
|
||||
self.conv1 = nn.Conv2d(3, 6, 1)
|
||||
self.conv2 = nn.Conv2d(6, 16, 1)
|
||||
self.relu = nn.ReLU()
|
||||
self.relu1 = nn.ReLU()
|
||||
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.max_pool2d1 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
layers1 = [self.conv1, self.conv2, self.max_pool2d, self.relu]
|
||||
self.layer1 = nn.SequentialCell(layers1)
|
||||
|
||||
resnet_block1 = CellBlock(3, 6)
|
||||
resnet_block2 = CellBlock(6, 16)
|
||||
resnet_block3 = CellBlock(16, 32)
|
||||
layers = [resnet_block1, resnet_block2, resnet_block3]
|
||||
self.layer2 = nn.SequentialCell(layers)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.layer1(x)
|
||||
x = self.relu(x)
|
||||
x = self.layer2(x)
|
||||
return x
|
||||
|
||||
|
||||
def test_erase_subtree_node():
|
||||
"""
|
||||
Feature: for parser and erase api.
|
||||
Description: erase a node in subtree of `SymbolTree`.
|
||||
Expectation: Success.
|
||||
"""
|
||||
net = ForNetWithSubTree()
|
||||
stree = SymbolTree.create(net)
|
||||
|
||||
for node in stree.nodes():
|
||||
if node.get_name() == "layer1":
|
||||
subtree = TreeNodeHelper.get_sub_tree(node)
|
||||
orig_node_num = len(subtree.get_handler()._nodes)
|
||||
for n in subtree.nodes():
|
||||
if n.get_instance_type() == nn.MaxPool2d:
|
||||
input_node = n.get_inputs()[0]
|
||||
output_nodes = n.get_users()
|
||||
for out_node in output_nodes:
|
||||
out_node.set_arg_by_node(0, input_node)
|
||||
subtree.erase_node(n)
|
||||
break
|
||||
assert len(subtree.get_handler()._nodes) == orig_node_num - 1
|
||||
break
|
||||
|
||||
|
||||
def test_erase_subtree_node_01():
|
||||
"""
|
||||
Feature: for parser and erase api.
|
||||
Description: erase a node in subtree of `SymbolTree`.
|
||||
Expectation: Success.
|
||||
"""
|
||||
net = ForNetWithSubTree()
|
||||
stree = SymbolTree.create(net)
|
||||
|
||||
for node in stree.nodes():
|
||||
if node.get_name() == "layer2":
|
||||
subtree = TreeNodeHelper.get_sub_tree(node)
|
||||
orig_node_num = len(subtree.get_handler()._nodes)
|
||||
for n in subtree.nodes():
|
||||
if n.get_name() == "cell_list_1":
|
||||
input_node = n.get_inputs()[0]
|
||||
output_nodes = n.get_users()
|
||||
for _nn in output_nodes:
|
||||
_nn.set_arg_by_node(0, input_node)
|
||||
subtree.erase_node(n)
|
||||
assert len(subtree.get_handler()._nodes) == orig_node_num - 1
|
||||
break
|
||||
break
|
||||
|
||||
|
||||
def test_erase_subtree_node_02():
|
||||
"""
|
||||
Feature: for parser and erase api.
|
||||
Description: for parser and erase node in subtree of `SymbolTree`.
|
||||
Expectation: Success.
|
||||
"""
|
||||
def _remove_bn(subtree):
|
||||
for node in subtree.nodes():
|
||||
if node.get_name() == "bn1":
|
||||
input_node = node.get_inputs()[0]
|
||||
output_nodes = node.get_users()
|
||||
for n in output_nodes:
|
||||
n.set_arg_by_node(0, input_node)
|
||||
subtree.erase_node(node)
|
||||
break
|
||||
|
||||
net = ForNetWithSubTree()
|
||||
stree = SymbolTree.create(net)
|
||||
for node in stree.nodes():
|
||||
if node.get_name() == "layer2":
|
||||
subtree = TreeNodeHelper.get_sub_tree(node)
|
||||
for n in subtree.nodes():
|
||||
if n.get_name() == "cell_list_1":
|
||||
subtree1 = TreeNodeHelper.get_sub_tree(n)
|
||||
_remove_bn(subtree1)
|
||||
assert subtree1.get_node("bn1") is None
|
||||
break
|
||||
|
||||
|
||||
def test_insert_subtree_node():
|
||||
"""
|
||||
Feature: for parser and insert api.
|
||||
Description: Insert node into subtree in `Symboltree`.
|
||||
Expectation: Success.
|
||||
"""
|
||||
def _insert_node(subtree):
|
||||
for node in subtree.nodes():
|
||||
if node.get_name() == "bn1":
|
||||
position = subtree.before(node)
|
||||
new_conv = nn.Conv2d(16, 16, 3)
|
||||
new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv',
|
||||
args=[ScopedValue.create_naming_value('self_max_po')])
|
||||
subtree.insert(position, new_conv_node)
|
||||
|
||||
net = ForNetWithSubTree()
|
||||
stree = SymbolTree.create(net)
|
||||
for node in stree.nodes():
|
||||
if node.get_name() == "layer2" and node.get_node_type() == NodeType.Tree:
|
||||
subtree = TreeNodeHelper.get_sub_tree(node)
|
||||
for n in subtree.nodes():
|
||||
if n.get_name() == "cell_list_1":
|
||||
subtree1 = TreeNodeHelper.get_sub_tree(n)
|
||||
orig_node_num = len(subtree1.get_handler()._nodes)
|
||||
_insert_node(subtree1)
|
||||
assert len(subtree1.get_handler()._nodes) == orig_node_num + 1
|
||||
|
||||
|
||||
def test_resnet_replace_121():
|
||||
"""
|
||||
Feature: for parser and replace api.
|
||||
Description: Replace one node by one nodes in subtree of `SymbolTree`..
|
||||
Expectation: Success.
|
||||
"""
|
||||
net = ForNetWithSubTree()
|
||||
stree: SymbolTree = SymbolTree.create(net)
|
||||
original_nodes_size = len(stree.get_handler()._nodes)
|
||||
for node in stree.nodes():
|
||||
if node.get_name() == "layer1" and node.get_node_type() == NodeType.Tree:
|
||||
subtree = TreeNodeHelper.get_sub_tree(node)
|
||||
for n in subtree.nodes():
|
||||
if n.get_instance_type() == nn.Conv2d:
|
||||
conv: nn.Conv2d = n.get_instance()
|
||||
new_conv = Node.create_call_cell(nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size),
|
||||
targets=n.get_targets(), args=n.get_args(),
|
||||
kwargs=node.get_kwargs(), name="new_conv")
|
||||
subtree.replace(n, [new_conv])
|
||||
break
|
||||
assert len(stree.get_handler()._nodes) == original_nodes_size
|
||||
|
||||
|
||||
def test_resnet_replace_12m():
|
||||
"""
|
||||
Feature: for parser and replace api.
|
||||
Description: Replace one node by multi-nodes in subtree of `SymbolTree`.
|
||||
Expectation: Success.
|
||||
"""
|
||||
net = ForNetWithSubTree()
|
||||
stree: SymbolTree = SymbolTree.create(net)
|
||||
|
||||
for node in stree.nodes():
|
||||
if node.get_name() == "layer1" and node.get_node_type() == NodeType.Tree:
|
||||
subtree = TreeNodeHelper.get_sub_tree(node)
|
||||
original_nodes_size = len(subtree.get_handler()._nodes)
|
||||
for n in subtree.nodes():
|
||||
if n.get_instance_type() == nn.Conv2d:
|
||||
conv: nn.Conv2d = n.get_instance()
|
||||
new_conv = Node.create_call_cell(nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size),
|
||||
targets=["x"], args=n.get_args(),
|
||||
kwargs=node.get_kwargs(), name="new_conv")
|
||||
new_bn = Node.create_call_cell(nn.BatchNorm2d(conv.out_channels),
|
||||
targets=n.get_targets(), args=[ScopedValue.create_naming_value("x")],
|
||||
kwargs={}, name="new_bn")
|
||||
subtree.replace(n, [new_conv, new_bn])
|
||||
break
|
||||
assert len(subtree.get_handler()._nodes) == original_nodes_size + 1
|
||||
|
||||
|
||||
def test_node_fusion_in_subtree():
|
||||
"""
|
||||
Feature: for parser and PatternEngine.
|
||||
Description: Apply PatternEngine on nodes in `SymbolTree`..
|
||||
Expectation: Success.
|
||||
"""
|
||||
net = ForNetWithSubTree()
|
||||
stree: SymbolTree = SymbolTree.create(net)
|
||||
original_nodes_size = len(stree.get_handler()._nodes)
|
||||
for node in stree.nodes():
|
||||
if node.get_name() == "layer1" and node.get_node_type() == NodeType.Tree:
|
||||
subtree = TreeNodeHelper.get_sub_tree(node)
|
||||
original_nodes_size = len(subtree.get_handler()._nodes)
|
||||
for n in subtree.nodes():
|
||||
node_: Node = n
|
||||
if node_.get_instance_type() == nn.Conv2d:
|
||||
old_bn = node_.get_users()[0]
|
||||
pos = subtree.after(node_)
|
||||
conv: nn.Conv2d = node_.get_instance()
|
||||
new_bn = Node.create_call_cell(nn.BatchNorm2d(conv.out_channels), targets=["x"],
|
||||
args=[node_.get_targets()[0]], kwargs={}, name="new_bn")
|
||||
subtree.insert(pos, new_bn)
|
||||
old_bn.set_arg_by_node(0, new_bn)
|
||||
break
|
||||
assert len(subtree.get_handler()._nodes) == original_nodes_size + 1
|
||||
ConvBnPattern().apply(subtree)
|
||||
assert len(subtree.get_handler()._nodes) == original_nodes_size
|
||||
assert not subtree.get_node("conv1")
|
||||
assert not subtree.get_node("new_bn")
|
Loading…
Reference in New Issue