!34567 fix bug in class_def_parser

Merge pull request !34567 from hangq/rewrite-pr
This commit is contained in:
i-robot 2022-05-19 01:20:10 +00:00 committed by Gitee
commit 77726dd5eb
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 96 additions and 4 deletions

View File

@ -120,6 +120,7 @@
"mindspore/tests/ut/python/rewrite/test_node.py" "syntax-error"
"mindspore/tests/ut/python/rewrite/test_node.py" "protected-access"
"mindspore/tests/ut/python/rewrite/test_symbol_tree.py" "len-as-condition"
"mindspore/tests/ut/python/rewrite/test_lenet.py" "protected-access"
"mindspore/tests/ut/python/test_log.py" "possibly-unused-variable"
"mindspore/tests/ut/python/test_log.py" "protected-access"
"mindspore/tests/ut/python/train/summary/test_summary_collector.py" "protected-access"

View File

@ -762,7 +762,7 @@ class Conv2dBnFoldQuantOneConv(Cell):
self.fake_quant_weight = quant_config.weight(ema=False,
channel_axis=channel_axis,
num_channels=out_channels)
self.freeze_bn = True
self.freeze_bn = False
self.bn_train = P.BatchNorm(is_training=True, epsilon=self.eps,
momentum=self.momentum, data_format=self.format)

View File

@ -168,7 +168,7 @@ class ClassDefParser(Parser):
if isinstance(body, ast.If):
ClassDefParser._remove_empty_ast_in_init_func(body.body)
ClassDefParser._remove_empty_ast_in_init_func(body.orelse)
if not body.body or not body.orelse:
if not body.body and not body.orelse:
body_index_to_be_deleted.append(body_index)
continue
if isinstance(body, ast.For):

View File

@ -23,7 +23,7 @@ from .symbol_tree import SymbolTree
from .node import TreeNode
from .parser_register import ParserRegister
from .parser import Parser
from .ast_transformers import FlattenRecursiveStmt, RemoveReturnOutOfIf
from .ast_transformers import FlattenRecursiveStmt
from .ast_helpers import AstModifier
from .ast_helpers import AstFinder
@ -71,7 +71,7 @@ class SymbolTreeBuilder:
Returns:
An instance of ast been optimized.
"""
transform_list = [FlattenRecursiveStmt(), RemoveReturnOutOfIf()]
transform_list = [FlattenRecursiveStmt()]
for transformer in transform_list:
ast_root = transformer.transform(ast_root)
return ast_root

View File

@ -0,0 +1,91 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LeNet."""
from collections import OrderedDict
import mindspore.nn as nn
from mindspore.common.initializer import Normal
from mindspore.rewrite import SymbolTree, PatternEngine, Replacement, PatternNode, Node
class LeNet5(nn.Cell):
def __init__(self, num_class=10, num_channel=1, include_top=True):
super(LeNet5, self).__init__()
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.include_top = include_top
if self.include_top:
self.flatten = nn.Flatten()
self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
if not self.include_top:
return x
x = self.flatten(x)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x
class ConvActReplace(Replacement):
def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
conv_p = pattern.get_inputs()[0]
conv_node: Node = matched.get(conv_p.name())
conv: nn.Conv2d = conv_node.get_instance()
newconv = nn.Conv2dBnAct(conv.in_channels,
conv.out_channels,
conv.kernel_size,
conv.stride,
conv.pad_mode,
conv.padding,
conv.dilation,
conv.group,
conv.has_bias,
conv.weight_init,
conv.bias_init,
False,
activation="relu")
newconv_node = Node.create_call_cell(newconv, conv_node.get_targets(), conv_node.get_args(),
conv_node.get_kwargs(), "Conv2dBnAct")
return [newconv_node]
class ConvReLUPattern(PatternEngine):
def __init__(self):
super().__init__([nn.Conv2d, nn.ReLU], ConvActReplace())
def test_lenet():
"""
Feature: Test PatternEngine.
Description: Test PatternEngine on Lenet5.
Expectation: Success.
"""
net = LeNet5(10)
stree = SymbolTree.create(net)
original_nodes_size = len(stree.get_handler()._nodes)
ConvReLUPattern().apply(stree)
assert len(stree.get_handler()._nodes) == original_nodes_size - 2