support const-test if
This commit is contained in:
parent
054e08d45d
commit
41e4d93ddd
|
@ -44,6 +44,7 @@
|
|||
"mindspore/mindspore/python/mindspore/rewrite/parser_register.py" "protected-access"
|
||||
"mindspore/mindspore/python/mindspore/rewrite/api/pattern_engine.py" "protected-access"
|
||||
"mindspore/mindspore/python/mindspore/rewrite/symbol_tree.py" "inconsistent-return-statements"
|
||||
"mindspore/mindspore/python/mindspore/rewrite/parsers/if_parser.py" "eval-used"
|
||||
"mindspore/model_zoo/official/cv" "missing-docstring"
|
||||
"mindspore/model_zoo/official/cv" "c-extension-no-member"
|
||||
"mindspore/model_zoo/official/nlp/bert_thor/src/bert_model.py" "redefined-outer-name"
|
||||
|
@ -121,6 +122,7 @@
|
|||
"mindspore/tests/ut/python/rewrite/test_node.py" "protected-access"
|
||||
"mindspore/tests/ut/python/rewrite/test_symbol_tree.py" "len-as-condition"
|
||||
"mindspore/tests/ut/python/rewrite/test_lenet.py" "protected-access"
|
||||
"mindspore/tests/ut/python/rewrite/test_if.py" "protected-access"
|
||||
"mindspore/tests/ut/python/test_log.py" "possibly-unused-variable"
|
||||
"mindspore/tests/ut/python/test_log.py" "protected-access"
|
||||
"mindspore/tests/ut/python/train/summary/test_summary_collector.py" "protected-access"
|
||||
|
|
|
@ -21,6 +21,7 @@ from .parsers.class_def_parser import g_classdef_parser
|
|||
from .parsers.function_def_parser import g_functiondef_parser
|
||||
from .parsers.arguments_parser import g_arguments_parser
|
||||
from .parsers.assign_parser import g_assign_parser
|
||||
from .parsers.if_parser import g_if_parser
|
||||
from .parsers.return_parser import g_return_parser
|
||||
from .api.scoped_value import ScopedValue, ValueType
|
||||
from .api.symbol_tree import SymbolTree
|
||||
|
|
|
@ -38,6 +38,15 @@ class AstModifier(ast.NodeTransformer):
|
|||
if id(body) == id(to_erase):
|
||||
ast_func.body.remove(body)
|
||||
return True
|
||||
if isinstance(body, ast.If):
|
||||
for if_body in body.body:
|
||||
if id(if_body) == id(to_erase):
|
||||
body.body.remove(if_body)
|
||||
return True
|
||||
for else_body in body.orelse:
|
||||
if id(else_body) == id(to_erase):
|
||||
body.orelse.remove(else_body)
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
|
@ -151,14 +160,34 @@ class AstModifier(ast.NodeTransformer):
|
|||
ast.fix_missing_locations(ast_func)
|
||||
return ast_assign
|
||||
for index in range(0, len(ast_func.body)):
|
||||
if id(ast_func.body[index]) == id(index_ast):
|
||||
body = ast_func.body[index]
|
||||
if id(body) == id(index_ast):
|
||||
if insert_before:
|
||||
ast_func.body.insert(index, ast_assign)
|
||||
else:
|
||||
ast_func.body.insert(index + 1, ast_assign)
|
||||
ast.fix_missing_locations(ast_func)
|
||||
return ast_assign
|
||||
raise RuntimeError("index_ast is not contained in ast_func")
|
||||
if isinstance(body, ast.If):
|
||||
for if_index in range(0, len(body.body)):
|
||||
if_body = body.body[if_index]
|
||||
if id(if_body) == id(index_ast):
|
||||
if insert_before:
|
||||
body.body.insert(if_index, ast_assign)
|
||||
else:
|
||||
body.body.insert(if_index + 1, ast_assign)
|
||||
ast.fix_missing_locations(body)
|
||||
return ast_assign
|
||||
for if_index in range(0, len(body.orelse)):
|
||||
else_body = body.orelse[if_index]
|
||||
if id(else_body) == id(index_ast):
|
||||
if insert_before:
|
||||
body.orelse.insert(if_index, ast_assign)
|
||||
else:
|
||||
body.orelse.insert(if_index + 1, ast_assign)
|
||||
ast.fix_missing_locations(body)
|
||||
return ast_assign
|
||||
raise RuntimeError("insert position is not contained in ast_func")
|
||||
|
||||
@staticmethod
|
||||
def append_global_vars_expr_to_init(init_func: ast.FunctionDef, targets: [ScopedValue],
|
||||
|
|
|
@ -30,6 +30,7 @@ from ..api.scoped_value import ScopedValue, ValueType
|
|||
from ..symbol_tree_builder import SymbolTreeBuilder
|
||||
from ..ast_helpers import AstReplacer, AstModifier
|
||||
from ..common.event import Event
|
||||
from ..namespace import is_subtree
|
||||
|
||||
|
||||
class AssignParser(Parser):
|
||||
|
@ -190,9 +191,6 @@ class AssignParser(Parser):
|
|||
results[keyword.arg] = AssignParser._create_scopedvalue(keyword.value)
|
||||
return results
|
||||
|
||||
def _is_subtree_cell(self, cell: Cell) -> bool:
|
||||
return not type(cell).__name__ in self._cell_namespce
|
||||
|
||||
@staticmethod
|
||||
def _find_op_and_type(func_scope, func_name, stree: SymbolTree):
|
||||
"""
|
||||
|
@ -331,7 +329,7 @@ class AssignParser(Parser):
|
|||
if isinstance(op, Primitive):
|
||||
return Node.create_call_buildin_op(op, father_ast_node, targets, func, call_args, call_kwargs, func_name)
|
||||
if isinstance(op, Cell):
|
||||
is_sub_tree = self._is_subtree_cell(op)
|
||||
is_sub_tree = is_subtree(type(op).__name__)
|
||||
if is_sub_tree:
|
||||
stb = SymbolTreeBuilder(op)
|
||||
new_stree = stb.build()
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Parse ast.If in construct function to node of SymbolTree."""
|
||||
|
||||
import ast
|
||||
import astunparse
|
||||
|
||||
from ..symbol_tree import SymbolTree
|
||||
from ..parser import Parser
|
||||
from ..parser_register import ParserRegister, reg_parser
|
||||
|
||||
|
||||
class IfParser(Parser):
|
||||
"""Parse ast.If in construct function to node of SymbolTree."""
|
||||
|
||||
def target(self):
|
||||
"""Parse target type"""
|
||||
return ast.If
|
||||
|
||||
def process(self, stree: SymbolTree, node: ast.If):
|
||||
"""
|
||||
Parse ast.If and create a node in symbol tree.
|
||||
|
||||
Args:
|
||||
stree ([SymbolTree]): Symbol Tree under parsing.
|
||||
node ([ast.If]): An ast.If node.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If test of ast.If can not be eval.
|
||||
"""
|
||||
|
||||
test_code = astunparse.unparse(node.test)
|
||||
test_code = test_code.replace("self", "stree.get_origin_network()")
|
||||
bodies = None
|
||||
try:
|
||||
test_value = eval(test_code)
|
||||
if test_value:
|
||||
bodies = node.body
|
||||
else:
|
||||
bodies = node.orelse
|
||||
except Exception:
|
||||
raise NotImplementedError("Only support ast.If whose test can be eval, got:", test_code)
|
||||
|
||||
for body in bodies:
|
||||
parser: Parser = ParserRegister.instance().get_parser(type(body))
|
||||
if parser is None:
|
||||
stree.append_python_node(node, body)
|
||||
else:
|
||||
parser.process(stree, body)
|
||||
|
||||
|
||||
g_if_parser = reg_parser(IfParser())
|
|
@ -131,30 +131,41 @@ class SymbolTree(Observer, Observable):
|
|||
Returns:
|
||||
An instance of Node represents root of input nodes.
|
||||
"""
|
||||
consumers: [ScopedValue] = []
|
||||
consumers: {ScopedValue: [Node]} = {}
|
||||
target_dict: {ScopedValue: Node} = {}
|
||||
for node in nodes:
|
||||
consumers.extend(node.get_args())
|
||||
for arg in node.get_args():
|
||||
if consumers.get(arg):
|
||||
consumers[arg].append(node)
|
||||
else:
|
||||
consumers[arg] = [node]
|
||||
for _, arg in node.get_kwargs():
|
||||
consumers.append(arg)
|
||||
if consumers.get(arg):
|
||||
consumers[arg].append(node)
|
||||
else:
|
||||
consumers[arg] = [node]
|
||||
for target in node.get_targets():
|
||||
if target_dict.get(target) is not None:
|
||||
raise RuntimeError("Target of node duplicated")
|
||||
raise RuntimeError(f"Target({target}) of node duplicated")
|
||||
target_dict[target] = node
|
||||
# find root node
|
||||
root = None
|
||||
for node in nodes:
|
||||
used = 0
|
||||
for target in node.get_targets():
|
||||
if target in consumers:
|
||||
used += 1
|
||||
break
|
||||
cur_consumers = consumers.get(target)
|
||||
if not cur_consumers:
|
||||
continue
|
||||
for cur_consumer in cur_consumers:
|
||||
if id(cur_consumer) != id(node):
|
||||
used += 1
|
||||
break
|
||||
if used == 0:
|
||||
if root is not None:
|
||||
raise RuntimeError("Replacement should only has one root")
|
||||
raise RuntimeError("Replacement should only has one root, found multi-root")
|
||||
root = node
|
||||
if root is None:
|
||||
raise RuntimeError("No root node found in replacement nodes")
|
||||
raise RuntimeError("Replacement should only has one root, found no root")
|
||||
# link node's input
|
||||
for node in nodes:
|
||||
inputs = []
|
||||
|
|
|
@ -0,0 +1,248 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""LeNet."""
|
||||
from collections import OrderedDict
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as P
|
||||
from mindspore.rewrite import SymbolTree, PatternEngine, Replacement, PatternNode, Node, ScopedValue
|
||||
|
||||
|
||||
class IfNet(nn.Cell):
|
||||
def __init__(self, use_se=False, res_base=False):
|
||||
super(IfNet, self).__init__()
|
||||
|
||||
self.use_se = use_se
|
||||
self.res_base = res_base
|
||||
self.se_block = False
|
||||
if self.use_se:
|
||||
self.se_block = True
|
||||
|
||||
if self.use_se:
|
||||
self.conv1_0 = nn.Conv2d(3, 32, 3, stride=2, padding=0, pad_mode='same')
|
||||
self.bn1_0 = nn.BatchNorm2d(32)
|
||||
self.conv1_1 = nn.Conv2d(32, 32, 3, stride=1, padding=0, pad_mode='same')
|
||||
self.bn1_1 = nn.BatchNorm2d(32)
|
||||
self.conv1_2 = nn.Conv2d(32, 64, 3, stride=1, padding=0, pad_mode='same')
|
||||
else:
|
||||
self.conv1 = nn.Conv2d(3, 64, 7, stride=2, padding=0, pad_mode='same')
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.relu = P.ReLU()
|
||||
|
||||
if self.res_base:
|
||||
self.pad = nn.Pad(paddings=((0, 0), (0, 0), (1, 1), (1, 1)))
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="valid")
|
||||
else:
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
|
||||
|
||||
self.mean = P.ReduceMean(keep_dims=True)
|
||||
self.flatten = nn.Flatten()
|
||||
self.end_point = nn.Dense(in_channels=2048, out_channels=10, has_bias=True, bias_init=0)
|
||||
|
||||
def construct(self, x):
|
||||
if self.use_se:
|
||||
x = self.conv1_0(x)
|
||||
x = self.bn1_0(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv1_1(x)
|
||||
x = self.bn1_1(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv1_2(x)
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
if self.res_base:
|
||||
x = self.pad(x)
|
||||
c1 = self.maxpool(x)
|
||||
|
||||
out = self.mean(c1, (2, 3))
|
||||
out = self.flatten(out)
|
||||
out = self.end_point(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ConvBnReplace(Replacement):
|
||||
def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
|
||||
bn_node: Node = matched.get(pattern.name())
|
||||
bn: nn.BatchNorm2d = bn_node.get_instance()
|
||||
conv_p = pattern.get_inputs()[0]
|
||||
conv_node: Node = matched.get(conv_p.name())
|
||||
conv: nn.Conv2d = conv_node.get_instance()
|
||||
newconv = nn.Conv2dBnAct(conv.in_channels,
|
||||
conv.out_channels,
|
||||
conv.kernel_size,
|
||||
conv.stride,
|
||||
conv.pad_mode,
|
||||
conv.padding,
|
||||
conv.dilation,
|
||||
conv.group,
|
||||
conv.has_bias,
|
||||
conv.weight_init,
|
||||
conv.bias_init,
|
||||
True,
|
||||
bn.momentum,
|
||||
bn.eps)
|
||||
newconv_node = Node.create_call_cell(newconv, bn_node.get_targets(), conv_node.get_args(),
|
||||
conv_node.get_kwargs(), "Conv2dBnAct")
|
||||
return [newconv_node]
|
||||
|
||||
|
||||
class ConvBnPattern(PatternEngine):
|
||||
def __init__(self):
|
||||
super().__init__([nn.Conv2d, nn.BatchNorm2d], ConvBnReplace())
|
||||
|
||||
|
||||
def test_resnet_erase_in_if():
|
||||
"""
|
||||
Feature: erase_node api and if_parser
|
||||
Description: erase a node in ast.If.
|
||||
Expectation: Success.
|
||||
"""
|
||||
net = IfNet()
|
||||
stree: SymbolTree = SymbolTree.create(net)
|
||||
original_nodes_size = len(stree.get_handler()._nodes)
|
||||
for node in stree.nodes():
|
||||
node_: Node = node
|
||||
if node_.get_instance_type() == nn.Conv2d:
|
||||
input_ = node_.get_inputs()[0]
|
||||
output = node_.get_users()[0]
|
||||
output.set_arg_by_node(0, input_)
|
||||
stree.erase_node(node)
|
||||
break
|
||||
assert len(stree.get_handler()._nodes) == original_nodes_size - 1
|
||||
|
||||
|
||||
def test_resnet_insert_in_if():
|
||||
"""
|
||||
Feature: insert api and if_parser
|
||||
Description: insert a node into ast.If.
|
||||
Expectation: Success.
|
||||
"""
|
||||
net = IfNet()
|
||||
stree: SymbolTree = SymbolTree.create(net)
|
||||
original_nodes_size = len(stree.get_handler()._nodes)
|
||||
for node in stree.nodes():
|
||||
node_: Node = node
|
||||
if node_.get_instance_type() == nn.Conv2d:
|
||||
pos = stree.after(node_)
|
||||
conv: nn.Conv2d = node_.get_instance()
|
||||
new_bn = Node.create_call_cell(nn.BatchNorm2d(conv.out_channels), targets=["x"],
|
||||
args=[ScopedValue.create_naming_value("x")], kwargs={}, name="new_bn")
|
||||
stree.insert(pos, new_bn)
|
||||
break
|
||||
assert len(stree.get_handler()._nodes) == original_nodes_size + 1
|
||||
|
||||
|
||||
def test_resnet_replace_121_in_if():
|
||||
"""
|
||||
Feature: replace api and if_parser
|
||||
Description: Replace one node by one nodes in ast.If.
|
||||
Expectation: Success.
|
||||
"""
|
||||
net = IfNet()
|
||||
stree: SymbolTree = SymbolTree.create(net)
|
||||
original_nodes_size = len(stree.get_handler()._nodes)
|
||||
for node in stree.nodes():
|
||||
node_: Node = node
|
||||
if node_.get_instance_type() == nn.Conv2d:
|
||||
conv: nn.Conv2d = node_.get_instance()
|
||||
new_conv = Node.create_call_cell(nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size),
|
||||
targets=node_.get_targets(), args=node_.get_args(),
|
||||
kwargs=node.get_kwargs(), name="new_conv")
|
||||
stree.replace(node_, [new_conv])
|
||||
break
|
||||
assert len(stree.get_handler()._nodes) == original_nodes_size
|
||||
|
||||
|
||||
def test_resnet_replace_12m_in_if():
|
||||
"""
|
||||
Feature: replace api and if_parser
|
||||
Description: Replace one node by multi-nodes in ast.If.
|
||||
Expectation: Success.
|
||||
"""
|
||||
net = IfNet()
|
||||
stree: SymbolTree = SymbolTree.create(net)
|
||||
original_nodes_size = len(stree.get_handler()._nodes)
|
||||
for node in stree.nodes():
|
||||
node_: Node = node
|
||||
if node_.get_instance_type() == nn.Conv2d:
|
||||
conv: nn.Conv2d = node_.get_instance()
|
||||
new_conv = Node.create_call_cell(nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size),
|
||||
targets=["x"], args=node_.get_args(),
|
||||
kwargs=node.get_kwargs(), name="new_conv")
|
||||
new_bn = Node.create_call_cell(nn.BatchNorm2d(conv.out_channels),
|
||||
targets=node_.get_targets(), args=[ScopedValue.create_naming_value("x")],
|
||||
kwargs={}, name="new_bn")
|
||||
stree.replace(node_, [new_conv, new_bn])
|
||||
break
|
||||
assert len(stree.get_handler()._nodes) == original_nodes_size + 1
|
||||
|
||||
|
||||
def test_resnet_fusion_in_if():
|
||||
"""
|
||||
Feature: PatternEngine and if_parser
|
||||
Description: Apply PatternEngine on nodes in ast.If.
|
||||
Expectation: Success.
|
||||
"""
|
||||
net = IfNet()
|
||||
stree: SymbolTree = SymbolTree.create(net)
|
||||
original_nodes_size = len(stree.get_handler()._nodes)
|
||||
for node in stree.nodes():
|
||||
node_: Node = node
|
||||
if node_.get_instance_type() == nn.Conv2d:
|
||||
old_bn = node_.get_users()[0]
|
||||
pos = stree.after(node_)
|
||||
conv: nn.Conv2d = node_.get_instance()
|
||||
new_bn = Node.create_call_cell(nn.BatchNorm2d(conv.out_channels), targets=["x"],
|
||||
args=[node_.get_targets()[0]], kwargs={}, name="new_bn")
|
||||
stree.insert(pos, new_bn)
|
||||
old_bn.set_arg_by_node(0, new_bn)
|
||||
break
|
||||
assert len(stree.get_handler()._nodes) == original_nodes_size + 1
|
||||
ConvBnPattern().apply(stree)
|
||||
print(stree.get_code())
|
||||
assert len(stree.get_handler()._nodes) == original_nodes_size
|
||||
assert not stree.get_node("conv1")
|
||||
assert not stree.get_node("new_bn")
|
||||
assert stree.get_node("bn1")
|
||||
|
||||
|
||||
def test_resnet_fusion_cross_if():
|
||||
"""
|
||||
Feature: PatternEngine and if_parser
|
||||
Description: Apply PatternEngine on nodes cross ast.If.
|
||||
Expectation: Success.
|
||||
"""
|
||||
net = IfNet()
|
||||
stree: SymbolTree = SymbolTree.create(net)
|
||||
original_nodes_size = len(stree.get_handler()._nodes)
|
||||
for node in stree.nodes():
|
||||
node_: Node = node
|
||||
if node_.get_instance_type() == nn.Conv2d:
|
||||
pos = stree.after(node_)
|
||||
conv: nn.Conv2d = node_.get_instance()
|
||||
new_bn = Node.create_call_cell(nn.BatchNorm2d(conv.out_channels), targets=["x"],
|
||||
args=[ScopedValue.create_naming_value("x")], kwargs={}, name="new_bn")
|
||||
stree.insert(pos, new_bn)
|
||||
break
|
||||
assert len(stree.get_handler()._nodes) == original_nodes_size + 1
|
||||
ConvBnPattern().apply(stree)
|
||||
print(stree.get_code())
|
||||
assert len(stree.get_handler()._nodes) == original_nodes_size
|
||||
assert not stree.get_node("conv1")
|
||||
assert stree.get_node("new_bn")
|
||||
assert not stree.get_node("bn1")
|
|
@ -357,6 +357,25 @@ def test_replace_one_to_one():
|
|||
assert new_conv_node.get_targets()[0] == list(relu2.get_normalized_args().values())[0]
|
||||
|
||||
|
||||
def test_replace_one_to_one_with_same_arg_and_target():
|
||||
"""
|
||||
Feature: Python api replace of SymbolTree of Rewrite.
|
||||
Description: Call replace to replace an origin node to a new node whose arg and target are same.
|
||||
Expectation: Success.
|
||||
"""
|
||||
stree, _, relu1, _ = create_symbol_tree()
|
||||
construct_ast: ast.FunctionDef = getattr(stree, "_root_ast")
|
||||
assert len(construct_ast.body) == 6
|
||||
assert get_symbol_tree_nodes_count(stree) == 7
|
||||
|
||||
new_conv = Conv2d(16, 16, 5)
|
||||
new_conv_node = NodeApi.create_call_cell(new_conv, [ScopedValue.create_naming_value("new_conv")],
|
||||
[ScopedValue.create_naming_value("new_conv")]).get_handler()
|
||||
stree.replace(relu1, [new_conv_node])
|
||||
assert get_symbol_tree_nodes_count(stree) == 7
|
||||
assert stree.get_node("new_conv")
|
||||
|
||||
|
||||
def test_replace_one_to_multi():
|
||||
"""
|
||||
Feature: Python api replace of SymbolTree of Rewrite.
|
||||
|
|
Loading…
Reference in New Issue