forked from mindspore-Ecosystem/mindspore
[MS][Rewrite]Move return out of if and support multiple targets
[MS][Rewrite]Move return out of if and support multiple targets [MS][Rewrite]Move return out of if and support multiple targets [MS][Rewrite]Move return out of if and support multiple targets [MS][Rewrite]Move return out of if and support multiple targets [MS][Rewrite]Move return out of if and support multiple targets [MS][Rewrite]Move return out of if and support multiple targets [MS][Rewrite]Move return out of if and support multiple targets [MS][Rewrite]Move return out of if and support multiple targets
This commit is contained in:
parent
3906a598ca
commit
317f59a0ae
|
@ -206,6 +206,7 @@ install(
|
|||
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/communication
|
||||
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/profiler
|
||||
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/compression
|
||||
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/rewrite
|
||||
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/run_check
|
||||
DESTINATION ${INSTALL_PY_DIR}
|
||||
COMPONENT mindspore
|
||||
|
|
|
@ -342,8 +342,8 @@ class AstModifier(ast.NodeTransformer):
|
|||
dst_ast.value = src_argument.value
|
||||
return
|
||||
if isinstance(dst_ast, ast.Name):
|
||||
if src_argument.type != ValueType.NamingValue:
|
||||
raise RuntimeError("src_argument.type should equal to ValueType.NamingValue")
|
||||
if src_argument.type not in [ValueType.NamingValue, ValueType.StringValue]:
|
||||
raise RuntimeError("src_argument.type should be ValueType.NamingValue or ValueType.StringValue.")
|
||||
if src_argument.scope:
|
||||
raise RuntimeError("src_argument.scope should be empty")
|
||||
dst_ast.id = src_argument.value
|
||||
|
|
|
@ -14,3 +14,4 @@
|
|||
# ============================================================================
|
||||
"""Transformers for optimizing ast."""
|
||||
from .flatten_recursive_stmt import FlattenRecursiveStmt
|
||||
from .remove_return_out_of_if import RemoveReturnOutOfIf
|
||||
|
|
|
@ -53,6 +53,8 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
|||
target_name = "return_value"
|
||||
elif isinstance(node, (ast.BinOp, ast.boolop, ast.UnaryOp)):
|
||||
target_name = type(node.op).__name__
|
||||
elif isinstance(node, ast.Tuple):
|
||||
target_name = type(node).__name__
|
||||
else:
|
||||
logger.warning("unhandled type of node while generating new target name: %s ", type(node))
|
||||
target_name = type(node).__name__
|
||||
|
@ -64,21 +66,6 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
|||
target_names.append(result)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _fill_in_original_target_names(target_names, node):
|
||||
"""Fill in original target names before getting unique names."""
|
||||
for function_index in range(len(node.body)):
|
||||
child = node.body[function_index]
|
||||
if not isinstance(child, ast.Assign):
|
||||
continue
|
||||
targets = child.targets
|
||||
for target in targets:
|
||||
if not isinstance(target, ast.Name):
|
||||
raise RuntimeError("currently only support ast.Name targets")
|
||||
target_name = target.id
|
||||
if target_name not in target_names:
|
||||
target_names.append(target_name)
|
||||
|
||||
@staticmethod
|
||||
def _create_new_assign_node(node: ast.AST, target_names) -> Tuple[str, ast.AST]:
|
||||
"""Create new assign node to be inserted into ast.FunctionDef."""
|
||||
|
@ -122,6 +109,35 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
|||
results.append(new_node)
|
||||
return results
|
||||
|
||||
def _fill_in_original_target_names(self, target_names, node):
|
||||
"""Fill in original target names before getting unique names."""
|
||||
for function_index in range(len(node.body)):
|
||||
child = node.body[function_index]
|
||||
if isinstance(child, (ast.Assign, ast.Expr)):
|
||||
child_value = child.value
|
||||
else:
|
||||
child_value = child
|
||||
if not self._flatten_table.get(type(child_value)):
|
||||
continue
|
||||
|
||||
if not isinstance(child, ast.Assign):
|
||||
continue
|
||||
targets = child.targets
|
||||
for target in targets:
|
||||
if not isinstance(target, (ast.Name, ast.Tuple)):
|
||||
raise RuntimeError("currently only support ast.Name targets")
|
||||
if isinstance(target, ast.Name):
|
||||
target_name = target.id
|
||||
if target_name not in target_names:
|
||||
target_names.append(target_name)
|
||||
elif isinstance(target, ast.Tuple):
|
||||
for elt in target.elts:
|
||||
if not isinstance(elt, ast.Name):
|
||||
raise RuntimeError("currently only support ast.Name in ast.Tuple.")
|
||||
target_name = elt.id
|
||||
if target_name not in target_names:
|
||||
target_names.append(target_name)
|
||||
|
||||
def visit_FunctionDef(self, node: FunctionDef) -> Any:
|
||||
"""Traverse construct node and flatten recursive nodes."""
|
||||
if node.name != "construct":
|
||||
|
|
|
@ -0,0 +1,225 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Fold if return."""
|
||||
|
||||
import ast
|
||||
import copy
|
||||
from typing import Any, Union
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ReturnType(Enum):
|
||||
"""
|
||||
ValueType represents type of nodes.
|
||||
|
||||
- A `NotReturn` represents the node is not a return node.
|
||||
- A `IfNotAllReturn` represents the node is an ast.If node and not all branches of it end with return.
|
||||
- A `Return` represents the node is a return node or an ast.If node of which all branches of it end with return.
|
||||
"""
|
||||
NotReturn = 0
|
||||
IfNotAllReturn = 1
|
||||
Return = 2
|
||||
|
||||
|
||||
class RemoveReturnOutOfIf(ast.NodeTransformer):
|
||||
"""
|
||||
Ast optimizer for removing all returns out of if control flow.
|
||||
|
||||
Example one:
|
||||
def func(x):
|
||||
if x == 1:
|
||||
return x - 1
|
||||
x += 1
|
||||
return x
|
||||
|
||||
will be optimized to
|
||||
def func(x):
|
||||
if x == 1:
|
||||
output_0 = x - 1
|
||||
else:
|
||||
x += 1
|
||||
output_0 = x
|
||||
return output_0
|
||||
|
||||
Example two:
|
||||
def func(x):
|
||||
if x == 1:
|
||||
x = 0
|
||||
elif x == 2:
|
||||
x = 4
|
||||
else:
|
||||
return x
|
||||
x += 1
|
||||
return x + 1
|
||||
|
||||
will be optimized to
|
||||
def func(x):
|
||||
if x == 1:
|
||||
x = 0
|
||||
x += 1
|
||||
output_1 = x + 1
|
||||
else:
|
||||
if x == 2:
|
||||
x = 4
|
||||
x += 1
|
||||
output_0 = x + 1
|
||||
else:
|
||||
output_0 = x
|
||||
output_1 = output_0
|
||||
return output_1
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _last_node_is_return(node: Union[ast.Return, ast.If]) -> ReturnType:
|
||||
"""
|
||||
Judge whether input node represents a return node.
|
||||
Return a numeric value according to different cases:
|
||||
0: Input node is not an ast.Return or input node is an ast.If of which all branches not end with ast.Return;
|
||||
1: Input node is an ast.If and not all branches end with ast.Return;
|
||||
2: Input node is an ast.Return or input node is an ast.If of which all branches end with ast.Return.
|
||||
"""
|
||||
if not isinstance(node, ast.Return) and not isinstance(node, ast.If):
|
||||
return ReturnType.NotReturn
|
||||
if isinstance(node, ast.Return): # last node is ast.Return
|
||||
return ReturnType.Return
|
||||
# all branches of ast.If not end with ast.Return
|
||||
if node.body and RemoveReturnOutOfIf._last_node_is_return(node.body[-1]) == ReturnType.NotReturn \
|
||||
and (not node.orelse or RemoveReturnOutOfIf._last_node_is_return(node.orelse[-1]) ==
|
||||
ReturnType.NotReturn):
|
||||
return ReturnType.NotReturn
|
||||
# all branches of ast.If end with ast.Return
|
||||
if node.body and RemoveReturnOutOfIf._last_node_is_return(node.body[-1]) == ReturnType.Return \
|
||||
and node.orelse and RemoveReturnOutOfIf._last_node_is_return(node.orelse[-1]) == ReturnType.Return:
|
||||
return ReturnType.Return
|
||||
# not all branches of ast.If end with ast.Return
|
||||
return ReturnType.IfNotAllReturn
|
||||
|
||||
@staticmethod
|
||||
def _fold_return(father_node: Union[ast.FunctionDef, ast.If], if_node: ast.If, if_index: int, attr: str):
|
||||
"""
|
||||
Fold following nodes into if node when not all branches of ast.If end with ast.Return.
|
||||
|
||||
Args:
|
||||
father_node (Union[ast.FunctionDef, ast.If]): Father node.
|
||||
if_node (ast.If): A if node.
|
||||
if_index (int): Index of the if node in body or or-else of father node.
|
||||
attr (str): Attribute of father node, can be 'body' or 'orelse'.
|
||||
|
||||
Raises:
|
||||
RuntimeError: Father node has not input attr.
|
||||
"""
|
||||
if not hasattr(father_node, attr):
|
||||
raise RuntimeError('Father node has not input attr', attr)
|
||||
father_node_attr = getattr(father_node, attr)
|
||||
if RemoveReturnOutOfIf._last_node_is_return(if_node) == ReturnType.IfNotAllReturn:
|
||||
# nodes should be copied to all branches which not end with return
|
||||
if if_node.body and RemoveReturnOutOfIf._last_node_is_return(if_node.body[-1]) != ReturnType.Return:
|
||||
for index in range(if_index + 1, len(father_node_attr)):
|
||||
node = copy.deepcopy(father_node_attr[index])
|
||||
if_node.body.append(node)
|
||||
if not if_node.orelse or (if_node.orelse and RemoveReturnOutOfIf._last_node_is_return(if_node.orelse[-1])
|
||||
!= ReturnType.Return):
|
||||
for index in range(if_index + 1, len(father_node_attr)):
|
||||
node = copy.deepcopy(father_node_attr[index])
|
||||
if_node.orelse.append(node)
|
||||
# delete original nodes
|
||||
remove_num = len(father_node_attr) - if_index - 1
|
||||
for _ in range(remove_num):
|
||||
father_node_attr.pop()
|
||||
|
||||
@staticmethod
|
||||
def _fold(father_node: Union[ast.FunctionDef, ast.If], attr: str):
|
||||
"""Fold nodes. Iterate into body and orelse of if node."""
|
||||
if not hasattr(father_node, attr) or not getattr(father_node, attr):
|
||||
return
|
||||
|
||||
if isinstance(getattr(father_node, attr)[-1], ast.If):
|
||||
RemoveReturnOutOfIf._fold(getattr(father_node, attr)[-1], 'body') # if.body
|
||||
RemoveReturnOutOfIf._fold(getattr(father_node, attr)[-1], 'orelse') # if.orelse
|
||||
|
||||
cur_index = len(getattr(father_node, attr)) - 2 # no following nodes to fold when if node is the last one
|
||||
while cur_index >= 0:
|
||||
child = getattr(father_node, attr)[cur_index]
|
||||
if isinstance(child, ast.If):
|
||||
RemoveReturnOutOfIf._fold_return(father_node, child, cur_index, attr)
|
||||
RemoveReturnOutOfIf._fold(child, 'body') # if.body
|
||||
RemoveReturnOutOfIf._fold(child, 'orelse') # if.orelse
|
||||
cur_index -= 1
|
||||
|
||||
@staticmethod
|
||||
def _get_output_names(output_names: [str]):
|
||||
"""Generate unique output names."""
|
||||
name: str = 'output_{}'.format(len(output_names))
|
||||
output_names.append(name)
|
||||
return name
|
||||
|
||||
@staticmethod
|
||||
def _move_out_return(output_names: [str], father_node: Union[ast.FunctionDef, ast.If], attr: str):
|
||||
"""
|
||||
Move all return node out of if nodes.
|
||||
Replace all original return nodes in ast.If with ast.Assign nodes which represent 'output = return value'.
|
||||
And add new ast.Return node to the end of father node.
|
||||
|
||||
Args:
|
||||
output_names ([str]): All unique output names.
|
||||
father_node (Union[ast.FunctionDef, ast.If]): Father node.
|
||||
attr (str): Attribute of father nodes, can be 'body' or 'orelse'.
|
||||
|
||||
Raises:
|
||||
RuntimeError: After iterative processing body and orelse of if nodes not all end with ast.Return.
|
||||
"""
|
||||
if not hasattr(father_node, attr) or not getattr(father_node, attr):
|
||||
return
|
||||
|
||||
last_node = getattr(father_node, attr)[-1]
|
||||
if isinstance(last_node, ast.If) and RemoveReturnOutOfIf._last_node_is_return(last_node) == ReturnType.Return:
|
||||
# the body or orelse of last if node should be ast.Return or ast.If
|
||||
if isinstance(last_node.body[-1], ast.If):
|
||||
RemoveReturnOutOfIf._move_out_return(output_names, last_node, 'body')
|
||||
if isinstance(last_node.orelse[-1], ast.If):
|
||||
RemoveReturnOutOfIf._move_out_return(output_names, last_node, 'orelse')
|
||||
|
||||
# assert body and or-else all end with return
|
||||
if not isinstance(last_node.body[-1], ast.Return) or not isinstance(last_node.orelse[-1], ast.Return):
|
||||
raise RuntimeError("Body and orelse of if nodes not all end with ast.Return.")
|
||||
output_name = RemoveReturnOutOfIf._get_output_names(output_names)
|
||||
# replace body return
|
||||
body_new_last_node = ast.Assign(
|
||||
targets=[ast.Name(id=output_name, ctx=ast.Store())], value=last_node.body[-1].value)
|
||||
last_node.body.pop()
|
||||
last_node.body.append(body_new_last_node)
|
||||
# replace else return
|
||||
else_new_last_node = ast.Assign(
|
||||
targets=[ast.Name(id=output_name, ctx=ast.Store())], value=last_node.orelse[-1].value)
|
||||
last_node.orelse.pop()
|
||||
last_node.orelse.append(else_new_last_node)
|
||||
# add new return node
|
||||
new_return_node = ast.Return(value=ast.Name(id=output_name, cts=ast.Store()))
|
||||
getattr(father_node, attr).append(new_return_node)
|
||||
|
||||
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
|
||||
"""Iterate construct node and fold following nodes into if node when condition is met."""
|
||||
if node.name != "construct":
|
||||
return node
|
||||
RemoveReturnOutOfIf._fold(node, 'body')
|
||||
output_names = []
|
||||
RemoveReturnOutOfIf._move_out_return(output_names, node, 'body')
|
||||
return node
|
||||
|
||||
def transform(self, ast_root):
|
||||
"""Transform."""
|
||||
ast_root = self.visit(ast_root)
|
||||
ast_root = ast.fix_missing_locations(ast_root)
|
||||
return ast_root
|
|
@ -526,13 +526,21 @@ class Node:
|
|||
raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast))
|
||||
# update targets
|
||||
targets_ast = assign_ast.targets
|
||||
if len(self._targets) != len(targets_ast):
|
||||
if isinstance(targets_ast[0], ast.Tuple) and len(self._targets) != len(targets_ast[0].elts):
|
||||
raise RuntimeError("self._targets should have the same length as targets_ast's elts")
|
||||
if not isinstance(targets_ast[0], ast.Tuple) and len(self._targets) != len(targets_ast):
|
||||
raise RuntimeError("self._targets should have targets_ast same length")
|
||||
for i in range(0, len(self._targets)):
|
||||
target = self._targets[i]
|
||||
target_ast = targets_ast[i]
|
||||
if not isinstance(target_ast, ast.Name):
|
||||
raise TypeError("target_ast should be ast.Name, got: ", type(target_ast))
|
||||
target_ast = targets_ast[0]
|
||||
if isinstance(target_ast, ast.Name):
|
||||
target_ast.id = target.value
|
||||
elif isinstance(target_ast, ast.Tuple):
|
||||
if not isinstance(target_ast.elts[i], ast.Name):
|
||||
raise TypeError("target should be ast.Name, got:", type(target_ast.elts[i]))
|
||||
target_ast.elts[i].id = target.value
|
||||
else:
|
||||
raise TypeError("target_ast should be ast.Name or ast.Tuple, got: ", type(target_ast))
|
||||
target_ast.id = target.value
|
||||
ast.fix_missing_locations(assign_ast)
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Parse ast.Assign in construct function to node of SymbolTree."""
|
||||
from typing import Union
|
||||
import ast
|
||||
import astunparse
|
||||
|
||||
|
@ -24,7 +25,7 @@ from ..symbol_tree import SymbolTree
|
|||
from ..node import Node, TreeNode
|
||||
from ..parser import Parser
|
||||
from ..parser_register import reg_parser
|
||||
from ..api.scoped_value import ScopedValue
|
||||
from ..api.scoped_value import ScopedValue, ValueType
|
||||
from ..symbol_tree_builder import SymbolTreeBuilder
|
||||
from ..ast_helpers import AstReplacer, AstModifier
|
||||
|
||||
|
@ -59,9 +60,12 @@ class AssignParser(Parser):
|
|||
tuple_elts = node.elts
|
||||
tuple_values = []
|
||||
for tuple_elt in tuple_elts:
|
||||
if not isinstance(tuple_elt, ast.Constant):
|
||||
raise RuntimeError("Only support ast.Constant as elts of ast.Tuple.")
|
||||
tuple_values.append(tuple_elt.value)
|
||||
if not isinstance(tuple_elt, (ast.Constant, ast.Name)):
|
||||
raise RuntimeError("Only support ast.Constant or ast.Name as elts of ast.Tuple.")
|
||||
if isinstance(tuple_elt, ast.Constant):
|
||||
tuple_values.append(tuple_elt.value)
|
||||
elif isinstance(tuple_elt, ast.Name):
|
||||
tuple_values.append(tuple_elt.id)
|
||||
return ScopedValue.create_variable_value(tuple(tuple_values))
|
||||
|
||||
@staticmethod
|
||||
|
@ -111,7 +115,9 @@ class AssignParser(Parser):
|
|||
return func.id
|
||||
if isinstance(func, ast.Attribute):
|
||||
return func.attr
|
||||
raise RuntimeError("FuncValue is should be Name or a Attribute:", astunparse.unparse(func))
|
||||
if isinstance(func, ast.Call):
|
||||
return AssignParser._get_func_name(func)
|
||||
raise RuntimeError("FuncValue is should be Name or a Attribute or a Call:", astunparse.unparse(func))
|
||||
|
||||
@staticmethod
|
||||
def _get_func_scope(ast_node: ast.Call) -> str:
|
||||
|
@ -136,7 +142,9 @@ class AssignParser(Parser):
|
|||
if not isinstance(value, ast.Name):
|
||||
raise RuntimeError("FuncValue is should be Name:", ast.dump(func))
|
||||
return value.id
|
||||
raise RuntimeError("FuncValue is should be Name or a Attribute:", ast.dump(func))
|
||||
if isinstance(func, ast.Call):
|
||||
return AssignParser._get_func_scope(func)
|
||||
raise RuntimeError("FuncValue should be Name or a Attribute or a Call:", ast.dump(func))
|
||||
|
||||
@staticmethod
|
||||
def _get_symbol_object(symbol_name, origin_net):
|
||||
|
@ -206,6 +214,19 @@ class AssignParser(Parser):
|
|||
return type(value), value
|
||||
return type(None), None
|
||||
|
||||
@staticmethod
|
||||
def _get_targets(all_targets: ScopedValue) -> [Union[ScopedValue, str]]:
|
||||
"""Get targets from tuple or single value."""
|
||||
targets: [Union[ScopedValue, str]] = []
|
||||
if all_targets.type == ValueType.TupleValue:
|
||||
for single_target in all_targets.value:
|
||||
if not isinstance(single_target, ScopedValue) and not isinstance(single_target.value, str):
|
||||
raise RuntimeError("Only support str target in tuple.")
|
||||
targets.append(single_target)
|
||||
else:
|
||||
targets.append(all_targets)
|
||||
return targets
|
||||
|
||||
def _update_field_in_init(self, func_scope, func_name, stree: SymbolTree, sub_tree: SymbolTree):
|
||||
"""
|
||||
When node is an invoking to sub-network, update value of ast.Assign of corresponding field in `__init__` method.
|
||||
|
@ -269,7 +290,7 @@ class AssignParser(Parser):
|
|||
Raises:
|
||||
RuntimeError: If operator instance invoked by assign is undefined.
|
||||
"""
|
||||
target = AssignParser._create_scopedvalue(father_ast_node.targets[0])
|
||||
targets = AssignParser._get_targets(AssignParser._create_scopedvalue(father_ast_node.targets[0]))
|
||||
func_name = AssignParser._get_func_name(ast_node)
|
||||
if func_name is None or func_name == "":
|
||||
raise RuntimeError("function name not exist")
|
||||
|
@ -283,7 +304,7 @@ class AssignParser(Parser):
|
|||
raise RuntimeError("Operator instance undefined: '", ast.unparse(ast_node.func), "' of '",
|
||||
ast.unparse(ast_node), "'")
|
||||
if isinstance(op, Primitive):
|
||||
return Node.create_call_buildin_op(op, father_ast_node, [target], func, call_args, call_kwargs, func_name)
|
||||
return Node.create_call_buildin_op(op, father_ast_node, targets, func, call_args, call_kwargs, func_name)
|
||||
if isinstance(op, Cell):
|
||||
is_sub_tree = self._is_subtree_cell(op)
|
||||
if is_sub_tree:
|
||||
|
@ -292,9 +313,9 @@ class AssignParser(Parser):
|
|||
self._update_field_in_init(func_scope, func_name, stree, new_stree)
|
||||
replacer = AstReplacer(new_stree.get_class_ast())
|
||||
replacer.replace_all(new_stree.get_ori_cls_name(), new_stree.get_opt_cls_name())
|
||||
return TreeNode(new_stree, father_ast_node, [target], func, call_args, call_kwargs, func_name,
|
||||
return TreeNode(new_stree, father_ast_node, targets, func, call_args, call_kwargs, func_name,
|
||||
new_stree.get_origin_network())
|
||||
return Node.create_call_buildin_op(op, father_ast_node, [target], func, call_args, call_kwargs, func_name)
|
||||
return Node.create_call_buildin_op(op, father_ast_node, targets, func, call_args, call_kwargs, func_name)
|
||||
raise RuntimeError("Only support Cell operator or Primitive operator, got ", type(op).__name__)
|
||||
|
||||
def process(self, stree: SymbolTree, node: ast.Assign):
|
||||
|
@ -332,9 +353,9 @@ class AssignParser(Parser):
|
|||
node_name = "constant_assign"
|
||||
else:
|
||||
node_name = "attribute_assign"
|
||||
target = AssignParser._create_scopedvalue(node.targets[0])
|
||||
targets = AssignParser._get_targets(AssignParser._create_scopedvalue(node.targets[0]))
|
||||
call_args = [AssignParser._create_scopedvalue(value)]
|
||||
node_ = Node.create_call_pass_through_method(node, [target], call_args, {}, node_name)
|
||||
node_ = Node.create_call_pass_through_method(node, targets, call_args, {}, node_name)
|
||||
stree.append_origin_field(node_)
|
||||
elif isinstance(value, (ast.List, ast.Tuple, ast.Dict)):
|
||||
# add these as callmethod node if necessary
|
||||
|
|
|
@ -23,7 +23,7 @@ from .symbol_tree import SymbolTree
|
|||
from .node import TreeNode
|
||||
from .parser_register import ParserRegister
|
||||
from .parser import Parser
|
||||
from .ast_transformers import FlattenRecursiveStmt
|
||||
from .ast_transformers import FlattenRecursiveStmt, RemoveReturnOutOfIf
|
||||
from .ast_helpers import AstModifier
|
||||
from .ast_helpers import AstFinder
|
||||
|
||||
|
@ -55,7 +55,7 @@ class SymbolTreeBuilder:
|
|||
Returns:
|
||||
An instance of ast been optimized.
|
||||
"""
|
||||
transform_list = [FlattenRecursiveStmt()]
|
||||
transform_list = [FlattenRecursiveStmt(), RemoveReturnOutOfIf()]
|
||||
for transformer in transform_list:
|
||||
ast_root = transformer.transform(ast_root)
|
||||
return ast_root
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from mindspore.nn import Cell, Conv2d
|
||||
from mindspore.rewrite import SymbolTree
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class SubNet(Cell):
|
||||
"""Sample cell which returns multiple features."""
|
||||
def __init__(self):
|
||||
"""Init."""
|
||||
super().__init__()
|
||||
self.conv = Conv2d(1, 10, 3)
|
||||
|
||||
def construct(self, x):
|
||||
"""Construct."""
|
||||
c1 = self.conv(x)
|
||||
c2 = self.conv(c1)
|
||||
c3 = self.conv(c2)
|
||||
return c1, c2, c3
|
||||
|
||||
|
||||
class NetMultiTargets(Cell):
|
||||
"""Test cls for multiple targets."""
|
||||
def __init__(self):
|
||||
"""Init."""
|
||||
super(NetMultiTargets, self).__init__()
|
||||
self.conv1 = SubNet()
|
||||
self.add = P.Add()
|
||||
|
||||
def construct(self, x):
|
||||
"""Construct."""
|
||||
c1, c2, c3 = self.conv1(x)
|
||||
x = self.add(c1, c2)
|
||||
x = self.add(x, c3)
|
||||
return x
|
||||
|
||||
|
||||
def test_multi_targets():
|
||||
"""
|
||||
Feature: Test multi-targets.
|
||||
Description: Test multi-targets.
|
||||
Expectation: Success.
|
||||
"""
|
||||
test_cls = NetMultiTargets()
|
||||
stree = SymbolTree.create(test_cls)
|
||||
node = stree.nodes()[2]
|
||||
targets = node.get_targets()
|
||||
assert targets[0].value == 'c1'
|
||||
assert targets[1].value == 'c2'
|
||||
assert targets[2].value == 'c3'
|
|
@ -0,0 +1,201 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import inspect
|
||||
import ast
|
||||
import astunparse
|
||||
|
||||
from mindspore.rewrite.ast_transformers import RemoveReturnOutOfIf
|
||||
|
||||
|
||||
class TestIf:
|
||||
"""Simple test."""
|
||||
@staticmethod
|
||||
def construct(x):
|
||||
"""construct"""
|
||||
if x > 2:
|
||||
return x - 2
|
||||
return x
|
||||
|
||||
|
||||
class TestIf2:
|
||||
"""Test multiple if and test if in if."""
|
||||
@staticmethod
|
||||
def construct(x):
|
||||
"""construct"""
|
||||
if x > 2:
|
||||
return x
|
||||
x += 2
|
||||
if x > 2:
|
||||
if x > 2:
|
||||
return x
|
||||
x += 2
|
||||
return x
|
||||
x *= 2
|
||||
return x
|
||||
|
||||
|
||||
class TestIf3:
|
||||
"""Test else."""
|
||||
@staticmethod
|
||||
def construct(x):
|
||||
"""construct"""
|
||||
if x > 2:
|
||||
x -= 2
|
||||
else:
|
||||
return x
|
||||
x -= 2
|
||||
return x
|
||||
|
||||
|
||||
class TestIf4:
|
||||
"""Test elif."""
|
||||
@staticmethod
|
||||
def construct(x):
|
||||
"""construct"""
|
||||
if x > 2:
|
||||
return x
|
||||
x += 2
|
||||
if x > 2:
|
||||
x += 1
|
||||
if x > 2:
|
||||
x *= 2
|
||||
elif x > 3:
|
||||
x /= 3
|
||||
else:
|
||||
return x
|
||||
x += 2
|
||||
return x
|
||||
x *= 2
|
||||
return x
|
||||
|
||||
|
||||
def test_simple_if():
|
||||
"""
|
||||
Feature: Test remove return from simple if.
|
||||
Description: Test remove return from simple if.
|
||||
Expectation: Success.
|
||||
"""
|
||||
ast_root: ast.Module = ast.parse(inspect.getsource(TestIf))
|
||||
folder = RemoveReturnOutOfIf()
|
||||
folder.transform(ast_root)
|
||||
assert astunparse.unparse(ast_root) == """\n\nclass TestIf():
|
||||
'Simple test.'\n
|
||||
@staticmethod
|
||||
def construct(x):
|
||||
'construct'
|
||||
if (x > 2):
|
||||
output_0 = (x - 2)
|
||||
else:
|
||||
output_0 = x
|
||||
return output_0
|
||||
"""
|
||||
|
||||
|
||||
def test_multiple_if():
|
||||
"""
|
||||
Feature: Test remove return from multiple if.
|
||||
Description: Test remove return from multiple if.
|
||||
Expectation: Success.
|
||||
"""
|
||||
ast_root: ast.Module = ast.parse(inspect.getsource(TestIf2))
|
||||
folder = RemoveReturnOutOfIf()
|
||||
folder.transform(ast_root)
|
||||
assert astunparse.unparse(ast_root) == """\n\nclass TestIf2():
|
||||
'Test multiple if and test if in if.'\n
|
||||
@staticmethod
|
||||
def construct(x):
|
||||
'construct'
|
||||
if (x > 2):
|
||||
output_2 = x
|
||||
else:
|
||||
x += 2
|
||||
if (x > 2):
|
||||
if (x > 2):
|
||||
output_0 = x
|
||||
else:
|
||||
x += 2
|
||||
output_0 = x
|
||||
output_1 = output_0
|
||||
else:
|
||||
x *= 2
|
||||
output_1 = x
|
||||
output_2 = output_1
|
||||
return output_2
|
||||
"""
|
||||
|
||||
|
||||
def test_else():
|
||||
"""
|
||||
Feature: Test remove return in else of if node.
|
||||
Description: Test remove return in else of if node.
|
||||
Expectation: Success.
|
||||
"""
|
||||
ast_root: ast.Module = ast.parse(inspect.getsource(TestIf3))
|
||||
folder = RemoveReturnOutOfIf()
|
||||
folder.transform(ast_root)
|
||||
assert astunparse.unparse(ast_root) == """\n\nclass TestIf3():
|
||||
'Test else.'\n
|
||||
@staticmethod
|
||||
def construct(x):
|
||||
'construct'
|
||||
if (x > 2):
|
||||
x -= 2
|
||||
x -= 2
|
||||
output_0 = x
|
||||
else:
|
||||
output_0 = x
|
||||
return output_0
|
||||
"""
|
||||
|
||||
|
||||
def test_elif():
|
||||
"""
|
||||
Feature: Test remove return from elif.
|
||||
Description: Test remove return from elif.
|
||||
Expectation: Success.
|
||||
"""
|
||||
ast_root: ast.Module = ast.parse(inspect.getsource(TestIf4))
|
||||
folder = RemoveReturnOutOfIf()
|
||||
folder.transform(ast_root)
|
||||
assert astunparse.unparse(ast_root) == """\n\nclass TestIf4():
|
||||
'Test elif.'\n
|
||||
@staticmethod
|
||||
def construct(x):
|
||||
'construct'
|
||||
if (x > 2):
|
||||
output_3 = x
|
||||
else:
|
||||
x += 2
|
||||
if (x > 2):
|
||||
x += 1
|
||||
if (x > 2):
|
||||
x *= 2
|
||||
x += 2
|
||||
output_1 = x
|
||||
else:
|
||||
if (x > 3):
|
||||
x /= 3
|
||||
x += 2
|
||||
output_0 = x
|
||||
else:
|
||||
output_0 = x
|
||||
output_1 = output_0
|
||||
output_2 = output_1
|
||||
else:
|
||||
x *= 2
|
||||
output_2 = x
|
||||
output_3 = output_2
|
||||
return output_3
|
||||
"""
|
Loading…
Reference in New Issue