diff --git a/.jenkins/check/config/filter_pylint.txt b/.jenkins/check/config/filter_pylint.txt index f404c458e4d..002faa0240e 100644 --- a/.jenkins/check/config/filter_pylint.txt +++ b/.jenkins/check/config/filter_pylint.txt @@ -43,6 +43,7 @@ "mindspore/mindspore/python/mindspore/rewrite/symbol_tree.py" "protected-access" "mindspore/mindspore/python/mindspore/rewrite/parser_register.py" "protected-access" "mindspore/mindspore/python/mindspore/rewrite/api/pattern_engine.py" "protected-access" +"mindspore/mindspore/python/mindspore/rewrite/parsers/for_parser.py" "eval-used" "mindspore/mindspore/python/mindspore/rewrite/symbol_tree.py" "inconsistent-return-statements" "mindspore/mindspore/python/mindspore/rewrite/parsers/if_parser.py" "eval-used" "mindspore/model_zoo/official/cv" "missing-docstring" @@ -120,6 +121,7 @@ "mindspore/tests/ut/python/rewrite/test_flatten_recursive_stmt.py" "consider-using-ternary" "mindspore/tests/ut/python/rewrite/test_node.py" "syntax-error" "mindspore/tests/ut/python/rewrite/test_node.py" "protected-access" +"mindspore/tests/ut/python/rewrite/test_for.py" "protected-access" "mindspore/tests/ut/python/rewrite/test_symbol_tree.py" "len-as-condition" "mindspore/tests/ut/python/rewrite/test_lenet.py" "protected-access" "mindspore/tests/ut/python/rewrite/test_if.py" "protected-access" diff --git a/mindspore/python/mindspore/rewrite/__init__.py b/mindspore/python/mindspore/rewrite/__init__.py index d0d01412591..f5cf7e5d03e 100644 --- a/mindspore/python/mindspore/rewrite/__init__.py +++ b/mindspore/python/mindspore/rewrite/__init__.py @@ -23,6 +23,7 @@ from .parsers.arguments_parser import g_arguments_parser from .parsers.assign_parser import g_assign_parser from .parsers.if_parser import g_if_parser from .parsers.return_parser import g_return_parser +from .parsers.for_parser import g_for_parser from .api.scoped_value import ScopedValue, ValueType from .api.symbol_tree import SymbolTree from .api.node import Node diff --git a/mindspore/python/mindspore/rewrite/ast_helpers/ast_modifier.py b/mindspore/python/mindspore/rewrite/ast_helpers/ast_modifier.py index a2da2bdc6f4..02e4eed907e 100644 --- a/mindspore/python/mindspore/rewrite/ast_helpers/ast_modifier.py +++ b/mindspore/python/mindspore/rewrite/ast_helpers/ast_modifier.py @@ -50,6 +50,12 @@ class AstModifier(ast.NodeTransformer): return True return False + @staticmethod + def erase_func_from_class_by_name(ast_class: ast.ClassDef, func_name: str): + for body in ast_class.body: + if isinstance(body, ast.FunctionDef) and body.name == func_name: + ast_class.body.remove(body) + @staticmethod def insert_sub_ast(ast_father: ast.AST, ast_son: ast.AST, index_ast: Optional[ast.AST] = None, insert_before=True) -> ast.AST: diff --git a/mindspore/python/mindspore/rewrite/parsers/for_parser.py b/mindspore/python/mindspore/rewrite/parsers/for_parser.py new file mode 100644 index 00000000000..f1d7ad60424 --- /dev/null +++ b/mindspore/python/mindspore/rewrite/parsers/for_parser.py @@ -0,0 +1,93 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" Parse ast.For node """ +import ast +import astunparse + +from mindspore.rewrite.api.scoped_value import ScopedValue, ValueType +from mindspore.rewrite.ast_helpers.ast_modifier import AstModifier +from mindspore import log as logger +from ..symbol_tree import SymbolTree +from ..parser import Parser +from ..parser_register import reg_parser +from ..common.event import Event + +EVAL_WHITE_LIST = ("self.", "range(", "zip(", "enumerate(", "reversed(") + + +class ForParser(Parser): + """ Class that implements parsing ast.For nodes """ + @staticmethod + def modify_init_ast(stree, i, obj, iter_var_name): + """Modify the ast node in init function.""" + target = f"{iter_var_name.strip()}_{str(i)}" + stree.add_global_vars(target, obj) + stree.get_origin_network().insert_child_to_cell(target, obj) + AstModifier.insert_assign_to_function(stree.get_init_func_ast(), + targets=[ScopedValue(ValueType.NamingValue, "self", target)], + expr=ScopedValue(ValueType.NamingValue, "global_vars", "get"), + args=[ScopedValue(ValueType.StringValue, "", target)]) + + @staticmethod + def modify_construct_ast(stree, ast_node, old_name, new_name): + """Modify the ast node in construct function.""" + node_str: str = astunparse.unparse(ast_node) + node_str = node_str.replace(old_name, new_name) + module_node = ast.parse(node_str) + new_node = module_node.body[0] + return new_node + + def target(self): + return ast.For + + def process(self, stree: SymbolTree, node: ast.For): + """ Process ast.For node """ + if isinstance(node.target, ast.Name): + targets = node.target.id + iter_code = astunparse.unparse(node.iter) + if not iter_code.startswith(EVAL_WHITE_LIST): + logger.warning(f"Illegal iteration condition for For node, it must start with{EVAL_WHITE_LIST}") + return + if iter_code.startswith("self"): + iter_code = iter_code.replace("self", "stree.get_origin_network()") + try: + iter_obj = eval(iter_code) + except Exception as e: + error_info = f"When eval '{iter_code}' by using JIT Fallback feature, an error occurred: {str(e)}" + logger.error(error_info) + raise e + + iter_var_name = iter_code.split(".")[-1] + index = stree.get_ast_root().body.index(node) + 1 + if isinstance(iter_obj, list): + for i, obj in enumerate(iter_obj): + ForParser.modify_init_ast(stree, i, obj, iter_var_name) + for body in node.body: + new_func_name = f"self.{iter_var_name.strip()}_{str(i)}".strip() + new_node = ForParser.modify_construct_ast(stree, body, targets, new_func_name) + stree.get_ast_root().body.insert(index, new_node) + index += 1 + if stree.get_ori_cls_name() == "SequentialCell": + stree.on_change(Event.CodeChangeEvent) + elif isinstance(iter_obj, range): + raise NotImplementedError("range not support") + elif isinstance(iter_obj, zip): + raise NotImplementedError("zip not support") + elif isinstance(iter_obj, enumerate): + raise NotImplementedError("enumerate not support") + else: + raise ValueError("not supported type: ", iter_obj) + +g_for_parser = reg_parser(ForParser()) diff --git a/mindspore/python/mindspore/rewrite/parsers/function_def_parser.py b/mindspore/python/mindspore/rewrite/parsers/function_def_parser.py index fbde74543dc..437c0e0829b 100644 --- a/mindspore/python/mindspore/rewrite/parsers/function_def_parser.py +++ b/mindspore/python/mindspore/rewrite/parsers/function_def_parser.py @@ -44,6 +44,9 @@ class FunctionDefParser(Parser): else: parser.process(stree, body) + for body in node.body: + if isinstance(body, ast.For): + node.body.remove(body) if hasattr(node, "decorator_list"): stree.try_append_python_node(node, node.decorator_list) if hasattr(node, "returns"): diff --git a/mindspore/python/mindspore/rewrite/symbol_tree.py b/mindspore/python/mindspore/rewrite/symbol_tree.py index 18f24d0c43a..9247eddbb21 100644 --- a/mindspore/python/mindspore/rewrite/symbol_tree.py +++ b/mindspore/python/mindspore/rewrite/symbol_tree.py @@ -766,8 +766,6 @@ class SymbolTree(Observer, Observable): value.isolate() break self._topo_mgr.on_erase_node(node) - if self._node_visitor: - self._node_visitor.remove_node(node) return node def replace(self, old_node: Node, new_nodes: [Node]) -> Node: @@ -932,6 +930,9 @@ class SymbolTree(Observer, Observable): body = self._module_ast.body[i] if not isinstance(body, (ast.Import, ast.ImportFrom)): continue + if isinstance(body, ast.ImportFrom) and body.module == "cell": + self._module_ast.body.remove(body) + continue for alias in body.names: name = alias.asname if alias.asname else alias.name if not str_checker.check(name): diff --git a/mindspore/python/mindspore/rewrite/symbol_tree_builder.py b/mindspore/python/mindspore/rewrite/symbol_tree_builder.py index f831ca4941d..6fce24c6eb0 100644 --- a/mindspore/python/mindspore/rewrite/symbol_tree_builder.py +++ b/mindspore/python/mindspore/rewrite/symbol_tree_builder.py @@ -54,6 +54,8 @@ class SymbolTreeBuilder: 3. Use merged ast.Module as module of main-network and sub-network. """ + if sub_stree.get_ori_cls_name() == "SequentialCell": + SymbolTreeBuilder._erase_unused_func_of_sequentialcell(sub_stree.get_class_ast()) father_mod = main_tree.get_module_ast() sub_mod = sub_stree.get_module_ast() SymbolTreeBuilder._merge_import_of_module(father_mod, sub_mod) @@ -122,6 +124,12 @@ class SymbolTreeBuilder: for clazz in classes_in_sub: AstModifier.insert_class_into_module(main_mod, clazz, first_class, True) + @staticmethod + def _erase_unused_func_of_sequentialcell(ast_class: ast.ClassDef): + func_names = ("__getitem__", "__setitem__", "__delitem__", "__len__", "append") + for name in func_names: + AstModifier.erase_func_from_class_by_name(ast_class, name) + def _merge_module_of_subtrees(self): """ Merge ast.Module of all sub-networks into ast.Module of main-network. diff --git a/tests/ut/python/rewrite/test_for.py b/tests/ut/python/rewrite/test_for.py new file mode 100644 index 00000000000..0b6059b68b5 --- /dev/null +++ b/tests/ut/python/rewrite/test_for.py @@ -0,0 +1,322 @@ + +from collections import OrderedDict + +from mindspore import nn +from mindspore.rewrite import SymbolTree, PatternEngine, Replacement, PatternNode, Node, ScopedValue +from mindspore.rewrite.api.tree_node_helper import TreeNodeHelper +from mindspore.rewrite.api.node_type import NodeType + + +def make_layer(block, layer_num, in_channel, out_channel, stride, use_se=False, se_block=False): + """ + Make stage network of ResNet. + + Args: + block (Cell): Resnet block. + layer_num (int): Layer number. + in_channel (int): Input channel. + out_channel (int): Output channel. + stride (int): Stride size for the first convolutional layer. + se_block(bool): Use se block in SE-ResNet50 net. Default: False. + Returns: + SequentialCell, the output layer. + + Examples: + >>> _make_layer(ResidualBlock, 3, 128, 256, 2) + """ + layers = [] + + resnet_block = block(in_channel, out_channel, stride=stride, use_se=use_se) + layers.append(resnet_block) + if se_block: + for _ in range(1, layer_num - 1): + resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se) + layers.append(resnet_block) + resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se, se_block=se_block) + layers.append(resnet_block) + else: + for _ in range(1, layer_num): + resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se) + layers.append(resnet_block) + return nn.SequentialCell(layers) + + +class ConvBnReplace(Replacement): + def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]: + bn_node: Node = matched.get(pattern.name()) + bn: nn.BatchNorm2d = bn_node.get_instance() + conv_p = pattern.get_inputs()[0] + conv_node: Node = matched.get(conv_p.name()) + conv: nn.Conv2d = conv_node.get_instance() + newconv = nn.Conv2dBnAct(conv.in_channels, + conv.out_channels, + conv.kernel_size, + conv.stride, + conv.pad_mode, + conv.padding, + conv.dilation, + conv.group, + conv.has_bias, + conv.weight_init, + conv.bias_init, + True, + bn.momentum, + bn.eps) + newconv_node = Node.create_call_cell(newconv, bn_node.get_targets(), conv_node.get_args(), + conv_node.get_kwargs(), "Conv2dBnAct") + return [newconv_node] + + +class ConvBnPattern(PatternEngine): + def __init__(self): + super().__init__([nn.Conv2d, nn.BatchNorm2d], ConvBnReplace()) + + +class CellBlock(nn.Cell): + """ + ResNet V1 residual block definition. + + Args: + in_channel (int): Input channel. + out_channel (int): Output channel. + stride (int): Stride size for the first convolutional layer. Default: 1. + use_se (bool): Enable SE-ResNet50 net. Default: False. + se_block(bool): Use se block in SE-ResNet50 net. Default: False. + + Returns: + Tensor, output tensor. + + Examples: + >>> ResidualBlock(3, 256, stride=2) + """ + expansion = 4 + + def __init__(self, in_channel, out_channel, stride=1,): + super(CellBlock, self).__init__() + self.conv1 = nn.Conv2d(3, 6, 1, stride=1) + self.bn1 = nn.BatchNorm2d(6, eps=1e-4, momentum=0.9, + gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1) + self.relu = nn.ReLU() + self.down_sample_layer = nn.SequentialCell([nn.Conv2d(in_channel, out_channel, 1)]) + + def construct(self, x): + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + x = self.down_sample_layer(x) + out = out + x + return out + + +class ForNetWithSubTree(nn.Cell): + def __init__(self): + super(ForNetWithSubTree, self).__init__() + self.conv1 = nn.Conv2d(3, 6, 1) + self.conv2 = nn.Conv2d(6, 16, 1) + self.relu = nn.ReLU() + self.relu1 = nn.ReLU() + self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + self.max_pool2d1 = nn.MaxPool2d(kernel_size=2, stride=2) + layers1 = [self.conv1, self.conv2, self.max_pool2d, self.relu] + self.layer1 = nn.SequentialCell(layers1) + + resnet_block1 = CellBlock(3, 6) + resnet_block2 = CellBlock(6, 16) + resnet_block3 = CellBlock(16, 32) + layers = [resnet_block1, resnet_block2, resnet_block3] + self.layer2 = nn.SequentialCell(layers) + + def construct(self, x): + x = self.conv1(x) + x = self.layer1(x) + x = self.relu(x) + x = self.layer2(x) + return x + + +def test_erase_subtree_node(): + """ + Feature: for parser and erase api. + Description: erase a node in subtree of `SymbolTree`. + Expectation: Success. + """ + net = ForNetWithSubTree() + stree = SymbolTree.create(net) + + for node in stree.nodes(): + if node.get_name() == "layer1": + subtree = TreeNodeHelper.get_sub_tree(node) + orig_node_num = len(subtree.get_handler()._nodes) + for n in subtree.nodes(): + if n.get_instance_type() == nn.MaxPool2d: + input_node = n.get_inputs()[0] + output_nodes = n.get_users() + for out_node in output_nodes: + out_node.set_arg_by_node(0, input_node) + subtree.erase_node(n) + break + assert len(subtree.get_handler()._nodes) == orig_node_num - 1 + break + + +def test_erase_subtree_node_01(): + """ + Feature: for parser and erase api. + Description: erase a node in subtree of `SymbolTree`. + Expectation: Success. + """ + net = ForNetWithSubTree() + stree = SymbolTree.create(net) + + for node in stree.nodes(): + if node.get_name() == "layer2": + subtree = TreeNodeHelper.get_sub_tree(node) + orig_node_num = len(subtree.get_handler()._nodes) + for n in subtree.nodes(): + if n.get_name() == "cell_list_1": + input_node = n.get_inputs()[0] + output_nodes = n.get_users() + for _nn in output_nodes: + _nn.set_arg_by_node(0, input_node) + subtree.erase_node(n) + assert len(subtree.get_handler()._nodes) == orig_node_num - 1 + break + break + + +def test_erase_subtree_node_02(): + """ + Feature: for parser and erase api. + Description: for parser and erase node in subtree of `SymbolTree`. + Expectation: Success. + """ + def _remove_bn(subtree): + for node in subtree.nodes(): + if node.get_name() == "bn1": + input_node = node.get_inputs()[0] + output_nodes = node.get_users() + for n in output_nodes: + n.set_arg_by_node(0, input_node) + subtree.erase_node(node) + break + + net = ForNetWithSubTree() + stree = SymbolTree.create(net) + for node in stree.nodes(): + if node.get_name() == "layer2": + subtree = TreeNodeHelper.get_sub_tree(node) + for n in subtree.nodes(): + if n.get_name() == "cell_list_1": + subtree1 = TreeNodeHelper.get_sub_tree(n) + _remove_bn(subtree1) + assert subtree1.get_node("bn1") is None + break + + +def test_insert_subtree_node(): + """ + Feature: for parser and insert api. + Description: Insert node into subtree in `Symboltree`. + Expectation: Success. + """ + def _insert_node(subtree): + for node in subtree.nodes(): + if node.get_name() == "bn1": + position = subtree.before(node) + new_conv = nn.Conv2d(16, 16, 3) + new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv', + args=[ScopedValue.create_naming_value('self_max_po')]) + subtree.insert(position, new_conv_node) + + net = ForNetWithSubTree() + stree = SymbolTree.create(net) + for node in stree.nodes(): + if node.get_name() == "layer2" and node.get_node_type() == NodeType.Tree: + subtree = TreeNodeHelper.get_sub_tree(node) + for n in subtree.nodes(): + if n.get_name() == "cell_list_1": + subtree1 = TreeNodeHelper.get_sub_tree(n) + orig_node_num = len(subtree1.get_handler()._nodes) + _insert_node(subtree1) + assert len(subtree1.get_handler()._nodes) == orig_node_num + 1 + + +def test_resnet_replace_121(): + """ + Feature: for parser and replace api. + Description: Replace one node by one nodes in subtree of `SymbolTree`.. + Expectation: Success. + """ + net = ForNetWithSubTree() + stree: SymbolTree = SymbolTree.create(net) + original_nodes_size = len(stree.get_handler()._nodes) + for node in stree.nodes(): + if node.get_name() == "layer1" and node.get_node_type() == NodeType.Tree: + subtree = TreeNodeHelper.get_sub_tree(node) + for n in subtree.nodes(): + if n.get_instance_type() == nn.Conv2d: + conv: nn.Conv2d = n.get_instance() + new_conv = Node.create_call_cell(nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size), + targets=n.get_targets(), args=n.get_args(), + kwargs=node.get_kwargs(), name="new_conv") + subtree.replace(n, [new_conv]) + break + assert len(stree.get_handler()._nodes) == original_nodes_size + + +def test_resnet_replace_12m(): + """ + Feature: for parser and replace api. + Description: Replace one node by multi-nodes in subtree of `SymbolTree`. + Expectation: Success. + """ + net = ForNetWithSubTree() + stree: SymbolTree = SymbolTree.create(net) + + for node in stree.nodes(): + if node.get_name() == "layer1" and node.get_node_type() == NodeType.Tree: + subtree = TreeNodeHelper.get_sub_tree(node) + original_nodes_size = len(subtree.get_handler()._nodes) + for n in subtree.nodes(): + if n.get_instance_type() == nn.Conv2d: + conv: nn.Conv2d = n.get_instance() + new_conv = Node.create_call_cell(nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size), + targets=["x"], args=n.get_args(), + kwargs=node.get_kwargs(), name="new_conv") + new_bn = Node.create_call_cell(nn.BatchNorm2d(conv.out_channels), + targets=n.get_targets(), args=[ScopedValue.create_naming_value("x")], + kwargs={}, name="new_bn") + subtree.replace(n, [new_conv, new_bn]) + break + assert len(subtree.get_handler()._nodes) == original_nodes_size + 1 + + +def test_node_fusion_in_subtree(): + """ + Feature: for parser and PatternEngine. + Description: Apply PatternEngine on nodes in `SymbolTree`.. + Expectation: Success. + """ + net = ForNetWithSubTree() + stree: SymbolTree = SymbolTree.create(net) + original_nodes_size = len(stree.get_handler()._nodes) + for node in stree.nodes(): + if node.get_name() == "layer1" and node.get_node_type() == NodeType.Tree: + subtree = TreeNodeHelper.get_sub_tree(node) + original_nodes_size = len(subtree.get_handler()._nodes) + for n in subtree.nodes(): + node_: Node = n + if node_.get_instance_type() == nn.Conv2d: + old_bn = node_.get_users()[0] + pos = subtree.after(node_) + conv: nn.Conv2d = node_.get_instance() + new_bn = Node.create_call_cell(nn.BatchNorm2d(conv.out_channels), targets=["x"], + args=[node_.get_targets()[0]], kwargs={}, name="new_bn") + subtree.insert(pos, new_bn) + old_bn.set_arg_by_node(0, new_bn) + break + assert len(subtree.get_handler()._nodes) == original_nodes_size + 1 + ConvBnPattern().apply(subtree) + assert len(subtree.get_handler()._nodes) == original_nodes_size + assert not subtree.get_node("conv1") + assert not subtree.get_node("new_bn")