!34805 support const iterator For

Merge pull request !34805 from 于振华/rewrite_for_unfold_0518
This commit is contained in:
i-robot 2022-05-28 08:45:22 +00:00 committed by Gitee
commit d70fe9e7da
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
8 changed files with 438 additions and 2 deletions

View File

@ -43,6 +43,7 @@
"mindspore/mindspore/python/mindspore/rewrite/symbol_tree.py" "protected-access"
"mindspore/mindspore/python/mindspore/rewrite/parser_register.py" "protected-access"
"mindspore/mindspore/python/mindspore/rewrite/api/pattern_engine.py" "protected-access"
"mindspore/mindspore/python/mindspore/rewrite/parsers/for_parser.py" "eval-used"
"mindspore/mindspore/python/mindspore/rewrite/symbol_tree.py" "inconsistent-return-statements"
"mindspore/mindspore/python/mindspore/rewrite/parsers/if_parser.py" "eval-used"
"mindspore/model_zoo/official/cv" "missing-docstring"
@ -120,6 +121,7 @@
"mindspore/tests/ut/python/rewrite/test_flatten_recursive_stmt.py" "consider-using-ternary"
"mindspore/tests/ut/python/rewrite/test_node.py" "syntax-error"
"mindspore/tests/ut/python/rewrite/test_node.py" "protected-access"
"mindspore/tests/ut/python/rewrite/test_for.py" "protected-access"
"mindspore/tests/ut/python/rewrite/test_symbol_tree.py" "len-as-condition"
"mindspore/tests/ut/python/rewrite/test_lenet.py" "protected-access"
"mindspore/tests/ut/python/rewrite/test_if.py" "protected-access"

View File

@ -23,6 +23,7 @@ from .parsers.arguments_parser import g_arguments_parser
from .parsers.assign_parser import g_assign_parser
from .parsers.if_parser import g_if_parser
from .parsers.return_parser import g_return_parser
from .parsers.for_parser import g_for_parser
from .api.scoped_value import ScopedValue, ValueType
from .api.symbol_tree import SymbolTree
from .api.node import Node

View File

@ -50,6 +50,12 @@ class AstModifier(ast.NodeTransformer):
return True
return False
@staticmethod
def erase_func_from_class_by_name(ast_class: ast.ClassDef, func_name: str):
for body in ast_class.body:
if isinstance(body, ast.FunctionDef) and body.name == func_name:
ast_class.body.remove(body)
@staticmethod
def insert_sub_ast(ast_father: ast.AST, ast_son: ast.AST, index_ast: Optional[ast.AST] = None,
insert_before=True) -> ast.AST:

View File

@ -0,0 +1,93 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Parse ast.For node """
import ast
import astunparse
from mindspore.rewrite.api.scoped_value import ScopedValue, ValueType
from mindspore.rewrite.ast_helpers.ast_modifier import AstModifier
from mindspore import log as logger
from ..symbol_tree import SymbolTree
from ..parser import Parser
from ..parser_register import reg_parser
from ..common.event import Event
EVAL_WHITE_LIST = ("self.", "range(", "zip(", "enumerate(", "reversed(")
class ForParser(Parser):
""" Class that implements parsing ast.For nodes """
@staticmethod
def modify_init_ast(stree, i, obj, iter_var_name):
"""Modify the ast node in init function."""
target = f"{iter_var_name.strip()}_{str(i)}"
stree.add_global_vars(target, obj)
stree.get_origin_network().insert_child_to_cell(target, obj)
AstModifier.insert_assign_to_function(stree.get_init_func_ast(),
targets=[ScopedValue(ValueType.NamingValue, "self", target)],
expr=ScopedValue(ValueType.NamingValue, "global_vars", "get"),
args=[ScopedValue(ValueType.StringValue, "", target)])
@staticmethod
def modify_construct_ast(stree, ast_node, old_name, new_name):
"""Modify the ast node in construct function."""
node_str: str = astunparse.unparse(ast_node)
node_str = node_str.replace(old_name, new_name)
module_node = ast.parse(node_str)
new_node = module_node.body[0]
return new_node
def target(self):
return ast.For
def process(self, stree: SymbolTree, node: ast.For):
""" Process ast.For node """
if isinstance(node.target, ast.Name):
targets = node.target.id
iter_code = astunparse.unparse(node.iter)
if not iter_code.startswith(EVAL_WHITE_LIST):
logger.warning(f"Illegal iteration condition for For node, it must start with{EVAL_WHITE_LIST}")
return
if iter_code.startswith("self"):
iter_code = iter_code.replace("self", "stree.get_origin_network()")
try:
iter_obj = eval(iter_code)
except Exception as e:
error_info = f"When eval '{iter_code}' by using JIT Fallback feature, an error occurred: {str(e)}"
logger.error(error_info)
raise e
iter_var_name = iter_code.split(".")[-1]
index = stree.get_ast_root().body.index(node) + 1
if isinstance(iter_obj, list):
for i, obj in enumerate(iter_obj):
ForParser.modify_init_ast(stree, i, obj, iter_var_name)
for body in node.body:
new_func_name = f"self.{iter_var_name.strip()}_{str(i)}".strip()
new_node = ForParser.modify_construct_ast(stree, body, targets, new_func_name)
stree.get_ast_root().body.insert(index, new_node)
index += 1
if stree.get_ori_cls_name() == "SequentialCell":
stree.on_change(Event.CodeChangeEvent)
elif isinstance(iter_obj, range):
raise NotImplementedError("range not support")
elif isinstance(iter_obj, zip):
raise NotImplementedError("zip not support")
elif isinstance(iter_obj, enumerate):
raise NotImplementedError("enumerate not support")
else:
raise ValueError("not supported type: ", iter_obj)
g_for_parser = reg_parser(ForParser())

View File

@ -44,6 +44,9 @@ class FunctionDefParser(Parser):
else:
parser.process(stree, body)
for body in node.body:
if isinstance(body, ast.For):
node.body.remove(body)
if hasattr(node, "decorator_list"):
stree.try_append_python_node(node, node.decorator_list)
if hasattr(node, "returns"):

View File

@ -766,8 +766,6 @@ class SymbolTree(Observer, Observable):
value.isolate()
break
self._topo_mgr.on_erase_node(node)
if self._node_visitor:
self._node_visitor.remove_node(node)
return node
def replace(self, old_node: Node, new_nodes: [Node]) -> Node:
@ -932,6 +930,9 @@ class SymbolTree(Observer, Observable):
body = self._module_ast.body[i]
if not isinstance(body, (ast.Import, ast.ImportFrom)):
continue
if isinstance(body, ast.ImportFrom) and body.module == "cell":
self._module_ast.body.remove(body)
continue
for alias in body.names:
name = alias.asname if alias.asname else alias.name
if not str_checker.check(name):

View File

@ -54,6 +54,8 @@ class SymbolTreeBuilder:
3. Use merged ast.Module as module of main-network and sub-network.
"""
if sub_stree.get_ori_cls_name() == "SequentialCell":
SymbolTreeBuilder._erase_unused_func_of_sequentialcell(sub_stree.get_class_ast())
father_mod = main_tree.get_module_ast()
sub_mod = sub_stree.get_module_ast()
SymbolTreeBuilder._merge_import_of_module(father_mod, sub_mod)
@ -122,6 +124,12 @@ class SymbolTreeBuilder:
for clazz in classes_in_sub:
AstModifier.insert_class_into_module(main_mod, clazz, first_class, True)
@staticmethod
def _erase_unused_func_of_sequentialcell(ast_class: ast.ClassDef):
func_names = ("__getitem__", "__setitem__", "__delitem__", "__len__", "append")
for name in func_names:
AstModifier.erase_func_from_class_by_name(ast_class, name)
def _merge_module_of_subtrees(self):
"""
Merge ast.Module of all sub-networks into ast.Module of main-network.

View File

@ -0,0 +1,322 @@
from collections import OrderedDict
from mindspore import nn
from mindspore.rewrite import SymbolTree, PatternEngine, Replacement, PatternNode, Node, ScopedValue
from mindspore.rewrite.api.tree_node_helper import TreeNodeHelper
from mindspore.rewrite.api.node_type import NodeType
def make_layer(block, layer_num, in_channel, out_channel, stride, use_se=False, se_block=False):
"""
Make stage network of ResNet.
Args:
block (Cell): Resnet block.
layer_num (int): Layer number.
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer.
se_block(bool): Use se block in SE-ResNet50 net. Default: False.
Returns:
SequentialCell, the output layer.
Examples:
>>> _make_layer(ResidualBlock, 3, 128, 256, 2)
"""
layers = []
resnet_block = block(in_channel, out_channel, stride=stride, use_se=use_se)
layers.append(resnet_block)
if se_block:
for _ in range(1, layer_num - 1):
resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
layers.append(resnet_block)
resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se, se_block=se_block)
layers.append(resnet_block)
else:
for _ in range(1, layer_num):
resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
layers.append(resnet_block)
return nn.SequentialCell(layers)
class ConvBnReplace(Replacement):
def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
bn_node: Node = matched.get(pattern.name())
bn: nn.BatchNorm2d = bn_node.get_instance()
conv_p = pattern.get_inputs()[0]
conv_node: Node = matched.get(conv_p.name())
conv: nn.Conv2d = conv_node.get_instance()
newconv = nn.Conv2dBnAct(conv.in_channels,
conv.out_channels,
conv.kernel_size,
conv.stride,
conv.pad_mode,
conv.padding,
conv.dilation,
conv.group,
conv.has_bias,
conv.weight_init,
conv.bias_init,
True,
bn.momentum,
bn.eps)
newconv_node = Node.create_call_cell(newconv, bn_node.get_targets(), conv_node.get_args(),
conv_node.get_kwargs(), "Conv2dBnAct")
return [newconv_node]
class ConvBnPattern(PatternEngine):
def __init__(self):
super().__init__([nn.Conv2d, nn.BatchNorm2d], ConvBnReplace())
class CellBlock(nn.Cell):
"""
ResNet V1 residual block definition.
Args:
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer. Default: 1.
use_se (bool): Enable SE-ResNet50 net. Default: False.
se_block(bool): Use se block in SE-ResNet50 net. Default: False.
Returns:
Tensor, output tensor.
Examples:
>>> ResidualBlock(3, 256, stride=2)
"""
expansion = 4
def __init__(self, in_channel, out_channel, stride=1,):
super(CellBlock, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 1, stride=1)
self.bn1 = nn.BatchNorm2d(6, eps=1e-4, momentum=0.9,
gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)
self.relu = nn.ReLU()
self.down_sample_layer = nn.SequentialCell([nn.Conv2d(in_channel, out_channel, 1)])
def construct(self, x):
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
x = self.down_sample_layer(x)
out = out + x
return out
class ForNetWithSubTree(nn.Cell):
def __init__(self):
super(ForNetWithSubTree, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 1)
self.conv2 = nn.Conv2d(6, 16, 1)
self.relu = nn.ReLU()
self.relu1 = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.max_pool2d1 = nn.MaxPool2d(kernel_size=2, stride=2)
layers1 = [self.conv1, self.conv2, self.max_pool2d, self.relu]
self.layer1 = nn.SequentialCell(layers1)
resnet_block1 = CellBlock(3, 6)
resnet_block2 = CellBlock(6, 16)
resnet_block3 = CellBlock(16, 32)
layers = [resnet_block1, resnet_block2, resnet_block3]
self.layer2 = nn.SequentialCell(layers)
def construct(self, x):
x = self.conv1(x)
x = self.layer1(x)
x = self.relu(x)
x = self.layer2(x)
return x
def test_erase_subtree_node():
"""
Feature: for parser and erase api.
Description: erase a node in subtree of `SymbolTree`.
Expectation: Success.
"""
net = ForNetWithSubTree()
stree = SymbolTree.create(net)
for node in stree.nodes():
if node.get_name() == "layer1":
subtree = TreeNodeHelper.get_sub_tree(node)
orig_node_num = len(subtree.get_handler()._nodes)
for n in subtree.nodes():
if n.get_instance_type() == nn.MaxPool2d:
input_node = n.get_inputs()[0]
output_nodes = n.get_users()
for out_node in output_nodes:
out_node.set_arg_by_node(0, input_node)
subtree.erase_node(n)
break
assert len(subtree.get_handler()._nodes) == orig_node_num - 1
break
def test_erase_subtree_node_01():
"""
Feature: for parser and erase api.
Description: erase a node in subtree of `SymbolTree`.
Expectation: Success.
"""
net = ForNetWithSubTree()
stree = SymbolTree.create(net)
for node in stree.nodes():
if node.get_name() == "layer2":
subtree = TreeNodeHelper.get_sub_tree(node)
orig_node_num = len(subtree.get_handler()._nodes)
for n in subtree.nodes():
if n.get_name() == "cell_list_1":
input_node = n.get_inputs()[0]
output_nodes = n.get_users()
for _nn in output_nodes:
_nn.set_arg_by_node(0, input_node)
subtree.erase_node(n)
assert len(subtree.get_handler()._nodes) == orig_node_num - 1
break
break
def test_erase_subtree_node_02():
"""
Feature: for parser and erase api.
Description: for parser and erase node in subtree of `SymbolTree`.
Expectation: Success.
"""
def _remove_bn(subtree):
for node in subtree.nodes():
if node.get_name() == "bn1":
input_node = node.get_inputs()[0]
output_nodes = node.get_users()
for n in output_nodes:
n.set_arg_by_node(0, input_node)
subtree.erase_node(node)
break
net = ForNetWithSubTree()
stree = SymbolTree.create(net)
for node in stree.nodes():
if node.get_name() == "layer2":
subtree = TreeNodeHelper.get_sub_tree(node)
for n in subtree.nodes():
if n.get_name() == "cell_list_1":
subtree1 = TreeNodeHelper.get_sub_tree(n)
_remove_bn(subtree1)
assert subtree1.get_node("bn1") is None
break
def test_insert_subtree_node():
"""
Feature: for parser and insert api.
Description: Insert node into subtree in `Symboltree`.
Expectation: Success.
"""
def _insert_node(subtree):
for node in subtree.nodes():
if node.get_name() == "bn1":
position = subtree.before(node)
new_conv = nn.Conv2d(16, 16, 3)
new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv',
args=[ScopedValue.create_naming_value('self_max_po')])
subtree.insert(position, new_conv_node)
net = ForNetWithSubTree()
stree = SymbolTree.create(net)
for node in stree.nodes():
if node.get_name() == "layer2" and node.get_node_type() == NodeType.Tree:
subtree = TreeNodeHelper.get_sub_tree(node)
for n in subtree.nodes():
if n.get_name() == "cell_list_1":
subtree1 = TreeNodeHelper.get_sub_tree(n)
orig_node_num = len(subtree1.get_handler()._nodes)
_insert_node(subtree1)
assert len(subtree1.get_handler()._nodes) == orig_node_num + 1
def test_resnet_replace_121():
"""
Feature: for parser and replace api.
Description: Replace one node by one nodes in subtree of `SymbolTree`..
Expectation: Success.
"""
net = ForNetWithSubTree()
stree: SymbolTree = SymbolTree.create(net)
original_nodes_size = len(stree.get_handler()._nodes)
for node in stree.nodes():
if node.get_name() == "layer1" and node.get_node_type() == NodeType.Tree:
subtree = TreeNodeHelper.get_sub_tree(node)
for n in subtree.nodes():
if n.get_instance_type() == nn.Conv2d:
conv: nn.Conv2d = n.get_instance()
new_conv = Node.create_call_cell(nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size),
targets=n.get_targets(), args=n.get_args(),
kwargs=node.get_kwargs(), name="new_conv")
subtree.replace(n, [new_conv])
break
assert len(stree.get_handler()._nodes) == original_nodes_size
def test_resnet_replace_12m():
"""
Feature: for parser and replace api.
Description: Replace one node by multi-nodes in subtree of `SymbolTree`.
Expectation: Success.
"""
net = ForNetWithSubTree()
stree: SymbolTree = SymbolTree.create(net)
for node in stree.nodes():
if node.get_name() == "layer1" and node.get_node_type() == NodeType.Tree:
subtree = TreeNodeHelper.get_sub_tree(node)
original_nodes_size = len(subtree.get_handler()._nodes)
for n in subtree.nodes():
if n.get_instance_type() == nn.Conv2d:
conv: nn.Conv2d = n.get_instance()
new_conv = Node.create_call_cell(nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size),
targets=["x"], args=n.get_args(),
kwargs=node.get_kwargs(), name="new_conv")
new_bn = Node.create_call_cell(nn.BatchNorm2d(conv.out_channels),
targets=n.get_targets(), args=[ScopedValue.create_naming_value("x")],
kwargs={}, name="new_bn")
subtree.replace(n, [new_conv, new_bn])
break
assert len(subtree.get_handler()._nodes) == original_nodes_size + 1
def test_node_fusion_in_subtree():
"""
Feature: for parser and PatternEngine.
Description: Apply PatternEngine on nodes in `SymbolTree`..
Expectation: Success.
"""
net = ForNetWithSubTree()
stree: SymbolTree = SymbolTree.create(net)
original_nodes_size = len(stree.get_handler()._nodes)
for node in stree.nodes():
if node.get_name() == "layer1" and node.get_node_type() == NodeType.Tree:
subtree = TreeNodeHelper.get_sub_tree(node)
original_nodes_size = len(subtree.get_handler()._nodes)
for n in subtree.nodes():
node_: Node = n
if node_.get_instance_type() == nn.Conv2d:
old_bn = node_.get_users()[0]
pos = subtree.after(node_)
conv: nn.Conv2d = node_.get_instance()
new_bn = Node.create_call_cell(nn.BatchNorm2d(conv.out_channels), targets=["x"],
args=[node_.get_targets()[0]], kwargs={}, name="new_bn")
subtree.insert(pos, new_bn)
old_bn.set_arg_by_node(0, new_bn)
break
assert len(subtree.get_handler()._nodes) == original_nodes_size + 1
ConvBnPattern().apply(subtree)
assert len(subtree.get_handler()._nodes) == original_nodes_size
assert not subtree.get_node("conv1")
assert not subtree.get_node("new_bn")