support const-test if

This commit is contained in:
hangangqiang 2022-05-19 16:48:02 +08:00
parent 054e08d45d
commit 41e4d93ddd
8 changed files with 387 additions and 15 deletions

View File

@ -44,6 +44,7 @@
"mindspore/mindspore/python/mindspore/rewrite/parser_register.py" "protected-access"
"mindspore/mindspore/python/mindspore/rewrite/api/pattern_engine.py" "protected-access"
"mindspore/mindspore/python/mindspore/rewrite/symbol_tree.py" "inconsistent-return-statements"
"mindspore/mindspore/python/mindspore/rewrite/parsers/if_parser.py" "eval-used"
"mindspore/model_zoo/official/cv" "missing-docstring"
"mindspore/model_zoo/official/cv" "c-extension-no-member"
"mindspore/model_zoo/official/nlp/bert_thor/src/bert_model.py" "redefined-outer-name"
@ -121,6 +122,7 @@
"mindspore/tests/ut/python/rewrite/test_node.py" "protected-access"
"mindspore/tests/ut/python/rewrite/test_symbol_tree.py" "len-as-condition"
"mindspore/tests/ut/python/rewrite/test_lenet.py" "protected-access"
"mindspore/tests/ut/python/rewrite/test_if.py" "protected-access"
"mindspore/tests/ut/python/test_log.py" "possibly-unused-variable"
"mindspore/tests/ut/python/test_log.py" "protected-access"
"mindspore/tests/ut/python/train/summary/test_summary_collector.py" "protected-access"

View File

@ -21,6 +21,7 @@ from .parsers.class_def_parser import g_classdef_parser
from .parsers.function_def_parser import g_functiondef_parser
from .parsers.arguments_parser import g_arguments_parser
from .parsers.assign_parser import g_assign_parser
from .parsers.if_parser import g_if_parser
from .parsers.return_parser import g_return_parser
from .api.scoped_value import ScopedValue, ValueType
from .api.symbol_tree import SymbolTree

View File

@ -38,6 +38,15 @@ class AstModifier(ast.NodeTransformer):
if id(body) == id(to_erase):
ast_func.body.remove(body)
return True
if isinstance(body, ast.If):
for if_body in body.body:
if id(if_body) == id(to_erase):
body.body.remove(if_body)
return True
for else_body in body.orelse:
if id(else_body) == id(to_erase):
body.orelse.remove(else_body)
return True
return False
@staticmethod
@ -151,14 +160,34 @@ class AstModifier(ast.NodeTransformer):
ast.fix_missing_locations(ast_func)
return ast_assign
for index in range(0, len(ast_func.body)):
if id(ast_func.body[index]) == id(index_ast):
body = ast_func.body[index]
if id(body) == id(index_ast):
if insert_before:
ast_func.body.insert(index, ast_assign)
else:
ast_func.body.insert(index + 1, ast_assign)
ast.fix_missing_locations(ast_func)
return ast_assign
raise RuntimeError("index_ast is not contained in ast_func")
if isinstance(body, ast.If):
for if_index in range(0, len(body.body)):
if_body = body.body[if_index]
if id(if_body) == id(index_ast):
if insert_before:
body.body.insert(if_index, ast_assign)
else:
body.body.insert(if_index + 1, ast_assign)
ast.fix_missing_locations(body)
return ast_assign
for if_index in range(0, len(body.orelse)):
else_body = body.orelse[if_index]
if id(else_body) == id(index_ast):
if insert_before:
body.orelse.insert(if_index, ast_assign)
else:
body.orelse.insert(if_index + 1, ast_assign)
ast.fix_missing_locations(body)
return ast_assign
raise RuntimeError("insert position is not contained in ast_func")
@staticmethod
def append_global_vars_expr_to_init(init_func: ast.FunctionDef, targets: [ScopedValue],

View File

@ -30,6 +30,7 @@ from ..api.scoped_value import ScopedValue, ValueType
from ..symbol_tree_builder import SymbolTreeBuilder
from ..ast_helpers import AstReplacer, AstModifier
from ..common.event import Event
from ..namespace import is_subtree
class AssignParser(Parser):
@ -190,9 +191,6 @@ class AssignParser(Parser):
results[keyword.arg] = AssignParser._create_scopedvalue(keyword.value)
return results
def _is_subtree_cell(self, cell: Cell) -> bool:
return not type(cell).__name__ in self._cell_namespce
@staticmethod
def _find_op_and_type(func_scope, func_name, stree: SymbolTree):
"""
@ -331,7 +329,7 @@ class AssignParser(Parser):
if isinstance(op, Primitive):
return Node.create_call_buildin_op(op, father_ast_node, targets, func, call_args, call_kwargs, func_name)
if isinstance(op, Cell):
is_sub_tree = self._is_subtree_cell(op)
is_sub_tree = is_subtree(type(op).__name__)
if is_sub_tree:
stb = SymbolTreeBuilder(op)
new_stree = stb.build()

View File

@ -0,0 +1,64 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse ast.If in construct function to node of SymbolTree."""
import ast
import astunparse
from ..symbol_tree import SymbolTree
from ..parser import Parser
from ..parser_register import ParserRegister, reg_parser
class IfParser(Parser):
"""Parse ast.If in construct function to node of SymbolTree."""
def target(self):
"""Parse target type"""
return ast.If
def process(self, stree: SymbolTree, node: ast.If):
"""
Parse ast.If and create a node in symbol tree.
Args:
stree ([SymbolTree]): Symbol Tree under parsing.
node ([ast.If]): An ast.If node.
Raises:
NotImplementedError: If test of ast.If can not be eval.
"""
test_code = astunparse.unparse(node.test)
test_code = test_code.replace("self", "stree.get_origin_network()")
bodies = None
try:
test_value = eval(test_code)
if test_value:
bodies = node.body
else:
bodies = node.orelse
except Exception:
raise NotImplementedError("Only support ast.If whose test can be eval, got:", test_code)
for body in bodies:
parser: Parser = ParserRegister.instance().get_parser(type(body))
if parser is None:
stree.append_python_node(node, body)
else:
parser.process(stree, body)
g_if_parser = reg_parser(IfParser())

View File

@ -131,30 +131,41 @@ class SymbolTree(Observer, Observable):
Returns:
An instance of Node represents root of input nodes.
"""
consumers: [ScopedValue] = []
consumers: {ScopedValue: [Node]} = {}
target_dict: {ScopedValue: Node} = {}
for node in nodes:
consumers.extend(node.get_args())
for arg in node.get_args():
if consumers.get(arg):
consumers[arg].append(node)
else:
consumers[arg] = [node]
for _, arg in node.get_kwargs():
consumers.append(arg)
if consumers.get(arg):
consumers[arg].append(node)
else:
consumers[arg] = [node]
for target in node.get_targets():
if target_dict.get(target) is not None:
raise RuntimeError("Target of node duplicated")
raise RuntimeError(f"Target({target}) of node duplicated")
target_dict[target] = node
# find root node
root = None
for node in nodes:
used = 0
for target in node.get_targets():
if target in consumers:
used += 1
break
cur_consumers = consumers.get(target)
if not cur_consumers:
continue
for cur_consumer in cur_consumers:
if id(cur_consumer) != id(node):
used += 1
break
if used == 0:
if root is not None:
raise RuntimeError("Replacement should only has one root")
raise RuntimeError("Replacement should only has one root, found multi-root")
root = node
if root is None:
raise RuntimeError("No root node found in replacement nodes")
raise RuntimeError("Replacement should only has one root, found no root")
# link node's input
for node in nodes:
inputs = []

View File

@ -0,0 +1,248 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LeNet."""
from collections import OrderedDict
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore.rewrite import SymbolTree, PatternEngine, Replacement, PatternNode, Node, ScopedValue
class IfNet(nn.Cell):
def __init__(self, use_se=False, res_base=False):
super(IfNet, self).__init__()
self.use_se = use_se
self.res_base = res_base
self.se_block = False
if self.use_se:
self.se_block = True
if self.use_se:
self.conv1_0 = nn.Conv2d(3, 32, 3, stride=2, padding=0, pad_mode='same')
self.bn1_0 = nn.BatchNorm2d(32)
self.conv1_1 = nn.Conv2d(32, 32, 3, stride=1, padding=0, pad_mode='same')
self.bn1_1 = nn.BatchNorm2d(32)
self.conv1_2 = nn.Conv2d(32, 64, 3, stride=1, padding=0, pad_mode='same')
else:
self.conv1 = nn.Conv2d(3, 64, 7, stride=2, padding=0, pad_mode='same')
self.bn1 = nn.BatchNorm2d(64)
self.relu = P.ReLU()
if self.res_base:
self.pad = nn.Pad(paddings=((0, 0), (0, 0), (1, 1), (1, 1)))
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="valid")
else:
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
self.mean = P.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
self.end_point = nn.Dense(in_channels=2048, out_channels=10, has_bias=True, bias_init=0)
def construct(self, x):
if self.use_se:
x = self.conv1_0(x)
x = self.bn1_0(x)
x = self.relu(x)
x = self.conv1_1(x)
x = self.bn1_1(x)
x = self.relu(x)
x = self.conv1_2(x)
else:
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
if self.res_base:
x = self.pad(x)
c1 = self.maxpool(x)
out = self.mean(c1, (2, 3))
out = self.flatten(out)
out = self.end_point(out)
return out
class ConvBnReplace(Replacement):
def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
bn_node: Node = matched.get(pattern.name())
bn: nn.BatchNorm2d = bn_node.get_instance()
conv_p = pattern.get_inputs()[0]
conv_node: Node = matched.get(conv_p.name())
conv: nn.Conv2d = conv_node.get_instance()
newconv = nn.Conv2dBnAct(conv.in_channels,
conv.out_channels,
conv.kernel_size,
conv.stride,
conv.pad_mode,
conv.padding,
conv.dilation,
conv.group,
conv.has_bias,
conv.weight_init,
conv.bias_init,
True,
bn.momentum,
bn.eps)
newconv_node = Node.create_call_cell(newconv, bn_node.get_targets(), conv_node.get_args(),
conv_node.get_kwargs(), "Conv2dBnAct")
return [newconv_node]
class ConvBnPattern(PatternEngine):
def __init__(self):
super().__init__([nn.Conv2d, nn.BatchNorm2d], ConvBnReplace())
def test_resnet_erase_in_if():
"""
Feature: erase_node api and if_parser
Description: erase a node in ast.If.
Expectation: Success.
"""
net = IfNet()
stree: SymbolTree = SymbolTree.create(net)
original_nodes_size = len(stree.get_handler()._nodes)
for node in stree.nodes():
node_: Node = node
if node_.get_instance_type() == nn.Conv2d:
input_ = node_.get_inputs()[0]
output = node_.get_users()[0]
output.set_arg_by_node(0, input_)
stree.erase_node(node)
break
assert len(stree.get_handler()._nodes) == original_nodes_size - 1
def test_resnet_insert_in_if():
"""
Feature: insert api and if_parser
Description: insert a node into ast.If.
Expectation: Success.
"""
net = IfNet()
stree: SymbolTree = SymbolTree.create(net)
original_nodes_size = len(stree.get_handler()._nodes)
for node in stree.nodes():
node_: Node = node
if node_.get_instance_type() == nn.Conv2d:
pos = stree.after(node_)
conv: nn.Conv2d = node_.get_instance()
new_bn = Node.create_call_cell(nn.BatchNorm2d(conv.out_channels), targets=["x"],
args=[ScopedValue.create_naming_value("x")], kwargs={}, name="new_bn")
stree.insert(pos, new_bn)
break
assert len(stree.get_handler()._nodes) == original_nodes_size + 1
def test_resnet_replace_121_in_if():
"""
Feature: replace api and if_parser
Description: Replace one node by one nodes in ast.If.
Expectation: Success.
"""
net = IfNet()
stree: SymbolTree = SymbolTree.create(net)
original_nodes_size = len(stree.get_handler()._nodes)
for node in stree.nodes():
node_: Node = node
if node_.get_instance_type() == nn.Conv2d:
conv: nn.Conv2d = node_.get_instance()
new_conv = Node.create_call_cell(nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size),
targets=node_.get_targets(), args=node_.get_args(),
kwargs=node.get_kwargs(), name="new_conv")
stree.replace(node_, [new_conv])
break
assert len(stree.get_handler()._nodes) == original_nodes_size
def test_resnet_replace_12m_in_if():
"""
Feature: replace api and if_parser
Description: Replace one node by multi-nodes in ast.If.
Expectation: Success.
"""
net = IfNet()
stree: SymbolTree = SymbolTree.create(net)
original_nodes_size = len(stree.get_handler()._nodes)
for node in stree.nodes():
node_: Node = node
if node_.get_instance_type() == nn.Conv2d:
conv: nn.Conv2d = node_.get_instance()
new_conv = Node.create_call_cell(nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size),
targets=["x"], args=node_.get_args(),
kwargs=node.get_kwargs(), name="new_conv")
new_bn = Node.create_call_cell(nn.BatchNorm2d(conv.out_channels),
targets=node_.get_targets(), args=[ScopedValue.create_naming_value("x")],
kwargs={}, name="new_bn")
stree.replace(node_, [new_conv, new_bn])
break
assert len(stree.get_handler()._nodes) == original_nodes_size + 1
def test_resnet_fusion_in_if():
"""
Feature: PatternEngine and if_parser
Description: Apply PatternEngine on nodes in ast.If.
Expectation: Success.
"""
net = IfNet()
stree: SymbolTree = SymbolTree.create(net)
original_nodes_size = len(stree.get_handler()._nodes)
for node in stree.nodes():
node_: Node = node
if node_.get_instance_type() == nn.Conv2d:
old_bn = node_.get_users()[0]
pos = stree.after(node_)
conv: nn.Conv2d = node_.get_instance()
new_bn = Node.create_call_cell(nn.BatchNorm2d(conv.out_channels), targets=["x"],
args=[node_.get_targets()[0]], kwargs={}, name="new_bn")
stree.insert(pos, new_bn)
old_bn.set_arg_by_node(0, new_bn)
break
assert len(stree.get_handler()._nodes) == original_nodes_size + 1
ConvBnPattern().apply(stree)
print(stree.get_code())
assert len(stree.get_handler()._nodes) == original_nodes_size
assert not stree.get_node("conv1")
assert not stree.get_node("new_bn")
assert stree.get_node("bn1")
def test_resnet_fusion_cross_if():
"""
Feature: PatternEngine and if_parser
Description: Apply PatternEngine on nodes cross ast.If.
Expectation: Success.
"""
net = IfNet()
stree: SymbolTree = SymbolTree.create(net)
original_nodes_size = len(stree.get_handler()._nodes)
for node in stree.nodes():
node_: Node = node
if node_.get_instance_type() == nn.Conv2d:
pos = stree.after(node_)
conv: nn.Conv2d = node_.get_instance()
new_bn = Node.create_call_cell(nn.BatchNorm2d(conv.out_channels), targets=["x"],
args=[ScopedValue.create_naming_value("x")], kwargs={}, name="new_bn")
stree.insert(pos, new_bn)
break
assert len(stree.get_handler()._nodes) == original_nodes_size + 1
ConvBnPattern().apply(stree)
print(stree.get_code())
assert len(stree.get_handler()._nodes) == original_nodes_size
assert not stree.get_node("conv1")
assert stree.get_node("new_bn")
assert not stree.get_node("bn1")

View File

@ -357,6 +357,25 @@ def test_replace_one_to_one():
assert new_conv_node.get_targets()[0] == list(relu2.get_normalized_args().values())[0]
def test_replace_one_to_one_with_same_arg_and_target():
"""
Feature: Python api replace of SymbolTree of Rewrite.
Description: Call replace to replace an origin node to a new node whose arg and target are same.
Expectation: Success.
"""
stree, _, relu1, _ = create_symbol_tree()
construct_ast: ast.FunctionDef = getattr(stree, "_root_ast")
assert len(construct_ast.body) == 6
assert get_symbol_tree_nodes_count(stree) == 7
new_conv = Conv2d(16, 16, 5)
new_conv_node = NodeApi.create_call_cell(new_conv, [ScopedValue.create_naming_value("new_conv")],
[ScopedValue.create_naming_value("new_conv")]).get_handler()
stree.replace(relu1, [new_conv_node])
assert get_symbol_tree_nodes_count(stree) == 7
assert stree.get_node("new_conv")
def test_replace_one_to_multi():
"""
Feature: Python api replace of SymbolTree of Rewrite.