[MS][Rewrite]Move return out of if and support multiple targets

[MS][Rewrite]Move return out of if and support multiple targets

[MS][Rewrite]Move return out of if and support multiple targets

[MS][Rewrite]Move return out of if and support multiple targets

[MS][Rewrite]Move return out of if and support multiple targets

[MS][Rewrite]Move return out of if and support multiple targets

[MS][Rewrite]Move return out of if and support multiple targets

[MS][Rewrite]Move return out of if and support multiple targets

[MS][Rewrite]Move return out of if and support multiple targets
This commit is contained in:
xiongkun 2022-04-09 18:36:50 +08:00
parent 3906a598ca
commit 317f59a0ae
10 changed files with 571 additions and 35 deletions

View File

@ -206,6 +206,7 @@ install(
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/communication
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/profiler
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/compression
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/rewrite
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/run_check
DESTINATION ${INSTALL_PY_DIR}
COMPONENT mindspore

View File

@ -342,8 +342,8 @@ class AstModifier(ast.NodeTransformer):
dst_ast.value = src_argument.value
return
if isinstance(dst_ast, ast.Name):
if src_argument.type != ValueType.NamingValue:
raise RuntimeError("src_argument.type should equal to ValueType.NamingValue")
if src_argument.type not in [ValueType.NamingValue, ValueType.StringValue]:
raise RuntimeError("src_argument.type should be ValueType.NamingValue or ValueType.StringValue.")
if src_argument.scope:
raise RuntimeError("src_argument.scope should be empty")
dst_ast.id = src_argument.value

View File

@ -14,3 +14,4 @@
# ============================================================================
"""Transformers for optimizing ast."""
from .flatten_recursive_stmt import FlattenRecursiveStmt
from .remove_return_out_of_if import RemoveReturnOutOfIf

View File

@ -53,6 +53,8 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
target_name = "return_value"
elif isinstance(node, (ast.BinOp, ast.boolop, ast.UnaryOp)):
target_name = type(node.op).__name__
elif isinstance(node, ast.Tuple):
target_name = type(node).__name__
else:
logger.warning("unhandled type of node while generating new target name: %s ", type(node))
target_name = type(node).__name__
@ -64,21 +66,6 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
target_names.append(result)
return result
@staticmethod
def _fill_in_original_target_names(target_names, node):
"""Fill in original target names before getting unique names."""
for function_index in range(len(node.body)):
child = node.body[function_index]
if not isinstance(child, ast.Assign):
continue
targets = child.targets
for target in targets:
if not isinstance(target, ast.Name):
raise RuntimeError("currently only support ast.Name targets")
target_name = target.id
if target_name not in target_names:
target_names.append(target_name)
@staticmethod
def _create_new_assign_node(node: ast.AST, target_names) -> Tuple[str, ast.AST]:
"""Create new assign node to be inserted into ast.FunctionDef."""
@ -122,6 +109,35 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
results.append(new_node)
return results
def _fill_in_original_target_names(self, target_names, node):
"""Fill in original target names before getting unique names."""
for function_index in range(len(node.body)):
child = node.body[function_index]
if isinstance(child, (ast.Assign, ast.Expr)):
child_value = child.value
else:
child_value = child
if not self._flatten_table.get(type(child_value)):
continue
if not isinstance(child, ast.Assign):
continue
targets = child.targets
for target in targets:
if not isinstance(target, (ast.Name, ast.Tuple)):
raise RuntimeError("currently only support ast.Name targets")
if isinstance(target, ast.Name):
target_name = target.id
if target_name not in target_names:
target_names.append(target_name)
elif isinstance(target, ast.Tuple):
for elt in target.elts:
if not isinstance(elt, ast.Name):
raise RuntimeError("currently only support ast.Name in ast.Tuple.")
target_name = elt.id
if target_name not in target_names:
target_names.append(target_name)
def visit_FunctionDef(self, node: FunctionDef) -> Any:
"""Traverse construct node and flatten recursive nodes."""
if node.name != "construct":

View File

@ -0,0 +1,225 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Fold if return."""
import ast
import copy
from typing import Any, Union
from enum import Enum
class ReturnType(Enum):
"""
ValueType represents type of nodes.
- A `NotReturn` represents the node is not a return node.
- A `IfNotAllReturn` represents the node is an ast.If node and not all branches of it end with return.
- A `Return` represents the node is a return node or an ast.If node of which all branches of it end with return.
"""
NotReturn = 0
IfNotAllReturn = 1
Return = 2
class RemoveReturnOutOfIf(ast.NodeTransformer):
"""
Ast optimizer for removing all returns out of if control flow.
Example one:
def func(x):
if x == 1:
return x - 1
x += 1
return x
will be optimized to
def func(x):
if x == 1:
output_0 = x - 1
else:
x += 1
output_0 = x
return output_0
Example two:
def func(x):
if x == 1:
x = 0
elif x == 2:
x = 4
else:
return x
x += 1
return x + 1
will be optimized to
def func(x):
if x == 1:
x = 0
x += 1
output_1 = x + 1
else:
if x == 2:
x = 4
x += 1
output_0 = x + 1
else:
output_0 = x
output_1 = output_0
return output_1
"""
@staticmethod
def _last_node_is_return(node: Union[ast.Return, ast.If]) -> ReturnType:
"""
Judge whether input node represents a return node.
Return a numeric value according to different cases:
0: Input node is not an ast.Return or input node is an ast.If of which all branches not end with ast.Return;
1: Input node is an ast.If and not all branches end with ast.Return;
2: Input node is an ast.Return or input node is an ast.If of which all branches end with ast.Return.
"""
if not isinstance(node, ast.Return) and not isinstance(node, ast.If):
return ReturnType.NotReturn
if isinstance(node, ast.Return): # last node is ast.Return
return ReturnType.Return
# all branches of ast.If not end with ast.Return
if node.body and RemoveReturnOutOfIf._last_node_is_return(node.body[-1]) == ReturnType.NotReturn \
and (not node.orelse or RemoveReturnOutOfIf._last_node_is_return(node.orelse[-1]) ==
ReturnType.NotReturn):
return ReturnType.NotReturn
# all branches of ast.If end with ast.Return
if node.body and RemoveReturnOutOfIf._last_node_is_return(node.body[-1]) == ReturnType.Return \
and node.orelse and RemoveReturnOutOfIf._last_node_is_return(node.orelse[-1]) == ReturnType.Return:
return ReturnType.Return
# not all branches of ast.If end with ast.Return
return ReturnType.IfNotAllReturn
@staticmethod
def _fold_return(father_node: Union[ast.FunctionDef, ast.If], if_node: ast.If, if_index: int, attr: str):
"""
Fold following nodes into if node when not all branches of ast.If end with ast.Return.
Args:
father_node (Union[ast.FunctionDef, ast.If]): Father node.
if_node (ast.If): A if node.
if_index (int): Index of the if node in body or or-else of father node.
attr (str): Attribute of father node, can be 'body' or 'orelse'.
Raises:
RuntimeError: Father node has not input attr.
"""
if not hasattr(father_node, attr):
raise RuntimeError('Father node has not input attr', attr)
father_node_attr = getattr(father_node, attr)
if RemoveReturnOutOfIf._last_node_is_return(if_node) == ReturnType.IfNotAllReturn:
# nodes should be copied to all branches which not end with return
if if_node.body and RemoveReturnOutOfIf._last_node_is_return(if_node.body[-1]) != ReturnType.Return:
for index in range(if_index + 1, len(father_node_attr)):
node = copy.deepcopy(father_node_attr[index])
if_node.body.append(node)
if not if_node.orelse or (if_node.orelse and RemoveReturnOutOfIf._last_node_is_return(if_node.orelse[-1])
!= ReturnType.Return):
for index in range(if_index + 1, len(father_node_attr)):
node = copy.deepcopy(father_node_attr[index])
if_node.orelse.append(node)
# delete original nodes
remove_num = len(father_node_attr) - if_index - 1
for _ in range(remove_num):
father_node_attr.pop()
@staticmethod
def _fold(father_node: Union[ast.FunctionDef, ast.If], attr: str):
"""Fold nodes. Iterate into body and orelse of if node."""
if not hasattr(father_node, attr) or not getattr(father_node, attr):
return
if isinstance(getattr(father_node, attr)[-1], ast.If):
RemoveReturnOutOfIf._fold(getattr(father_node, attr)[-1], 'body') # if.body
RemoveReturnOutOfIf._fold(getattr(father_node, attr)[-1], 'orelse') # if.orelse
cur_index = len(getattr(father_node, attr)) - 2 # no following nodes to fold when if node is the last one
while cur_index >= 0:
child = getattr(father_node, attr)[cur_index]
if isinstance(child, ast.If):
RemoveReturnOutOfIf._fold_return(father_node, child, cur_index, attr)
RemoveReturnOutOfIf._fold(child, 'body') # if.body
RemoveReturnOutOfIf._fold(child, 'orelse') # if.orelse
cur_index -= 1
@staticmethod
def _get_output_names(output_names: [str]):
"""Generate unique output names."""
name: str = 'output_{}'.format(len(output_names))
output_names.append(name)
return name
@staticmethod
def _move_out_return(output_names: [str], father_node: Union[ast.FunctionDef, ast.If], attr: str):
"""
Move all return node out of if nodes.
Replace all original return nodes in ast.If with ast.Assign nodes which represent 'output = return value'.
And add new ast.Return node to the end of father node.
Args:
output_names ([str]): All unique output names.
father_node (Union[ast.FunctionDef, ast.If]): Father node.
attr (str): Attribute of father nodes, can be 'body' or 'orelse'.
Raises:
RuntimeError: After iterative processing body and orelse of if nodes not all end with ast.Return.
"""
if not hasattr(father_node, attr) or not getattr(father_node, attr):
return
last_node = getattr(father_node, attr)[-1]
if isinstance(last_node, ast.If) and RemoveReturnOutOfIf._last_node_is_return(last_node) == ReturnType.Return:
# the body or orelse of last if node should be ast.Return or ast.If
if isinstance(last_node.body[-1], ast.If):
RemoveReturnOutOfIf._move_out_return(output_names, last_node, 'body')
if isinstance(last_node.orelse[-1], ast.If):
RemoveReturnOutOfIf._move_out_return(output_names, last_node, 'orelse')
# assert body and or-else all end with return
if not isinstance(last_node.body[-1], ast.Return) or not isinstance(last_node.orelse[-1], ast.Return):
raise RuntimeError("Body and orelse of if nodes not all end with ast.Return.")
output_name = RemoveReturnOutOfIf._get_output_names(output_names)
# replace body return
body_new_last_node = ast.Assign(
targets=[ast.Name(id=output_name, ctx=ast.Store())], value=last_node.body[-1].value)
last_node.body.pop()
last_node.body.append(body_new_last_node)
# replace else return
else_new_last_node = ast.Assign(
targets=[ast.Name(id=output_name, ctx=ast.Store())], value=last_node.orelse[-1].value)
last_node.orelse.pop()
last_node.orelse.append(else_new_last_node)
# add new return node
new_return_node = ast.Return(value=ast.Name(id=output_name, cts=ast.Store()))
getattr(father_node, attr).append(new_return_node)
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
"""Iterate construct node and fold following nodes into if node when condition is met."""
if node.name != "construct":
return node
RemoveReturnOutOfIf._fold(node, 'body')
output_names = []
RemoveReturnOutOfIf._move_out_return(output_names, node, 'body')
return node
def transform(self, ast_root):
"""Transform."""
ast_root = self.visit(ast_root)
ast_root = ast.fix_missing_locations(ast_root)
return ast_root

View File

@ -526,13 +526,21 @@ class Node:
raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast))
# update targets
targets_ast = assign_ast.targets
if len(self._targets) != len(targets_ast):
if isinstance(targets_ast[0], ast.Tuple) and len(self._targets) != len(targets_ast[0].elts):
raise RuntimeError("self._targets should have the same length as targets_ast's elts")
if not isinstance(targets_ast[0], ast.Tuple) and len(self._targets) != len(targets_ast):
raise RuntimeError("self._targets should have targets_ast same length")
for i in range(0, len(self._targets)):
target = self._targets[i]
target_ast = targets_ast[i]
if not isinstance(target_ast, ast.Name):
raise TypeError("target_ast should be ast.Name, got: ", type(target_ast))
target_ast = targets_ast[0]
if isinstance(target_ast, ast.Name):
target_ast.id = target.value
elif isinstance(target_ast, ast.Tuple):
if not isinstance(target_ast.elts[i], ast.Name):
raise TypeError("target should be ast.Name, got:", type(target_ast.elts[i]))
target_ast.elts[i].id = target.value
else:
raise TypeError("target_ast should be ast.Name or ast.Tuple, got: ", type(target_ast))
target_ast.id = target.value
ast.fix_missing_locations(assign_ast)

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""Parse ast.Assign in construct function to node of SymbolTree."""
from typing import Union
import ast
import astunparse
@ -24,7 +25,7 @@ from ..symbol_tree import SymbolTree
from ..node import Node, TreeNode
from ..parser import Parser
from ..parser_register import reg_parser
from ..api.scoped_value import ScopedValue
from ..api.scoped_value import ScopedValue, ValueType
from ..symbol_tree_builder import SymbolTreeBuilder
from ..ast_helpers import AstReplacer, AstModifier
@ -59,9 +60,12 @@ class AssignParser(Parser):
tuple_elts = node.elts
tuple_values = []
for tuple_elt in tuple_elts:
if not isinstance(tuple_elt, ast.Constant):
raise RuntimeError("Only support ast.Constant as elts of ast.Tuple.")
tuple_values.append(tuple_elt.value)
if not isinstance(tuple_elt, (ast.Constant, ast.Name)):
raise RuntimeError("Only support ast.Constant or ast.Name as elts of ast.Tuple.")
if isinstance(tuple_elt, ast.Constant):
tuple_values.append(tuple_elt.value)
elif isinstance(tuple_elt, ast.Name):
tuple_values.append(tuple_elt.id)
return ScopedValue.create_variable_value(tuple(tuple_values))
@staticmethod
@ -111,7 +115,9 @@ class AssignParser(Parser):
return func.id
if isinstance(func, ast.Attribute):
return func.attr
raise RuntimeError("FuncValue is should be Name or a Attribute:", astunparse.unparse(func))
if isinstance(func, ast.Call):
return AssignParser._get_func_name(func)
raise RuntimeError("FuncValue is should be Name or a Attribute or a Call:", astunparse.unparse(func))
@staticmethod
def _get_func_scope(ast_node: ast.Call) -> str:
@ -136,7 +142,9 @@ class AssignParser(Parser):
if not isinstance(value, ast.Name):
raise RuntimeError("FuncValue is should be Name:", ast.dump(func))
return value.id
raise RuntimeError("FuncValue is should be Name or a Attribute:", ast.dump(func))
if isinstance(func, ast.Call):
return AssignParser._get_func_scope(func)
raise RuntimeError("FuncValue should be Name or a Attribute or a Call:", ast.dump(func))
@staticmethod
def _get_symbol_object(symbol_name, origin_net):
@ -206,6 +214,19 @@ class AssignParser(Parser):
return type(value), value
return type(None), None
@staticmethod
def _get_targets(all_targets: ScopedValue) -> [Union[ScopedValue, str]]:
"""Get targets from tuple or single value."""
targets: [Union[ScopedValue, str]] = []
if all_targets.type == ValueType.TupleValue:
for single_target in all_targets.value:
if not isinstance(single_target, ScopedValue) and not isinstance(single_target.value, str):
raise RuntimeError("Only support str target in tuple.")
targets.append(single_target)
else:
targets.append(all_targets)
return targets
def _update_field_in_init(self, func_scope, func_name, stree: SymbolTree, sub_tree: SymbolTree):
"""
When node is an invoking to sub-network, update value of ast.Assign of corresponding field in `__init__` method.
@ -269,7 +290,7 @@ class AssignParser(Parser):
Raises:
RuntimeError: If operator instance invoked by assign is undefined.
"""
target = AssignParser._create_scopedvalue(father_ast_node.targets[0])
targets = AssignParser._get_targets(AssignParser._create_scopedvalue(father_ast_node.targets[0]))
func_name = AssignParser._get_func_name(ast_node)
if func_name is None or func_name == "":
raise RuntimeError("function name not exist")
@ -283,7 +304,7 @@ class AssignParser(Parser):
raise RuntimeError("Operator instance undefined: '", ast.unparse(ast_node.func), "' of '",
ast.unparse(ast_node), "'")
if isinstance(op, Primitive):
return Node.create_call_buildin_op(op, father_ast_node, [target], func, call_args, call_kwargs, func_name)
return Node.create_call_buildin_op(op, father_ast_node, targets, func, call_args, call_kwargs, func_name)
if isinstance(op, Cell):
is_sub_tree = self._is_subtree_cell(op)
if is_sub_tree:
@ -292,9 +313,9 @@ class AssignParser(Parser):
self._update_field_in_init(func_scope, func_name, stree, new_stree)
replacer = AstReplacer(new_stree.get_class_ast())
replacer.replace_all(new_stree.get_ori_cls_name(), new_stree.get_opt_cls_name())
return TreeNode(new_stree, father_ast_node, [target], func, call_args, call_kwargs, func_name,
return TreeNode(new_stree, father_ast_node, targets, func, call_args, call_kwargs, func_name,
new_stree.get_origin_network())
return Node.create_call_buildin_op(op, father_ast_node, [target], func, call_args, call_kwargs, func_name)
return Node.create_call_buildin_op(op, father_ast_node, targets, func, call_args, call_kwargs, func_name)
raise RuntimeError("Only support Cell operator or Primitive operator, got ", type(op).__name__)
def process(self, stree: SymbolTree, node: ast.Assign):
@ -332,9 +353,9 @@ class AssignParser(Parser):
node_name = "constant_assign"
else:
node_name = "attribute_assign"
target = AssignParser._create_scopedvalue(node.targets[0])
targets = AssignParser._get_targets(AssignParser._create_scopedvalue(node.targets[0]))
call_args = [AssignParser._create_scopedvalue(value)]
node_ = Node.create_call_pass_through_method(node, [target], call_args, {}, node_name)
node_ = Node.create_call_pass_through_method(node, targets, call_args, {}, node_name)
stree.append_origin_field(node_)
elif isinstance(value, (ast.List, ast.Tuple, ast.Dict)):
# add these as callmethod node if necessary

View File

@ -23,7 +23,7 @@ from .symbol_tree import SymbolTree
from .node import TreeNode
from .parser_register import ParserRegister
from .parser import Parser
from .ast_transformers import FlattenRecursiveStmt
from .ast_transformers import FlattenRecursiveStmt, RemoveReturnOutOfIf
from .ast_helpers import AstModifier
from .ast_helpers import AstFinder
@ -55,7 +55,7 @@ class SymbolTreeBuilder:
Returns:
An instance of ast been optimized.
"""
transform_list = [FlattenRecursiveStmt()]
transform_list = [FlattenRecursiveStmt(), RemoveReturnOutOfIf()]
for transformer in transform_list:
ast_root = transformer.transform(ast_root)
return ast_root

View File

@ -0,0 +1,63 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.nn import Cell, Conv2d
from mindspore.rewrite import SymbolTree
from mindspore.ops import operations as P
class SubNet(Cell):
"""Sample cell which returns multiple features."""
def __init__(self):
"""Init."""
super().__init__()
self.conv = Conv2d(1, 10, 3)
def construct(self, x):
"""Construct."""
c1 = self.conv(x)
c2 = self.conv(c1)
c3 = self.conv(c2)
return c1, c2, c3
class NetMultiTargets(Cell):
"""Test cls for multiple targets."""
def __init__(self):
"""Init."""
super(NetMultiTargets, self).__init__()
self.conv1 = SubNet()
self.add = P.Add()
def construct(self, x):
"""Construct."""
c1, c2, c3 = self.conv1(x)
x = self.add(c1, c2)
x = self.add(x, c3)
return x
def test_multi_targets():
"""
Feature: Test multi-targets.
Description: Test multi-targets.
Expectation: Success.
"""
test_cls = NetMultiTargets()
stree = SymbolTree.create(test_cls)
node = stree.nodes()[2]
targets = node.get_targets()
assert targets[0].value == 'c1'
assert targets[1].value == 'c2'
assert targets[2].value == 'c3'

View File

@ -0,0 +1,201 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import inspect
import ast
import astunparse
from mindspore.rewrite.ast_transformers import RemoveReturnOutOfIf
class TestIf:
"""Simple test."""
@staticmethod
def construct(x):
"""construct"""
if x > 2:
return x - 2
return x
class TestIf2:
"""Test multiple if and test if in if."""
@staticmethod
def construct(x):
"""construct"""
if x > 2:
return x
x += 2
if x > 2:
if x > 2:
return x
x += 2
return x
x *= 2
return x
class TestIf3:
"""Test else."""
@staticmethod
def construct(x):
"""construct"""
if x > 2:
x -= 2
else:
return x
x -= 2
return x
class TestIf4:
"""Test elif."""
@staticmethod
def construct(x):
"""construct"""
if x > 2:
return x
x += 2
if x > 2:
x += 1
if x > 2:
x *= 2
elif x > 3:
x /= 3
else:
return x
x += 2
return x
x *= 2
return x
def test_simple_if():
"""
Feature: Test remove return from simple if.
Description: Test remove return from simple if.
Expectation: Success.
"""
ast_root: ast.Module = ast.parse(inspect.getsource(TestIf))
folder = RemoveReturnOutOfIf()
folder.transform(ast_root)
assert astunparse.unparse(ast_root) == """\n\nclass TestIf():
'Simple test.'\n
@staticmethod
def construct(x):
'construct'
if (x > 2):
output_0 = (x - 2)
else:
output_0 = x
return output_0
"""
def test_multiple_if():
"""
Feature: Test remove return from multiple if.
Description: Test remove return from multiple if.
Expectation: Success.
"""
ast_root: ast.Module = ast.parse(inspect.getsource(TestIf2))
folder = RemoveReturnOutOfIf()
folder.transform(ast_root)
assert astunparse.unparse(ast_root) == """\n\nclass TestIf2():
'Test multiple if and test if in if.'\n
@staticmethod
def construct(x):
'construct'
if (x > 2):
output_2 = x
else:
x += 2
if (x > 2):
if (x > 2):
output_0 = x
else:
x += 2
output_0 = x
output_1 = output_0
else:
x *= 2
output_1 = x
output_2 = output_1
return output_2
"""
def test_else():
"""
Feature: Test remove return in else of if node.
Description: Test remove return in else of if node.
Expectation: Success.
"""
ast_root: ast.Module = ast.parse(inspect.getsource(TestIf3))
folder = RemoveReturnOutOfIf()
folder.transform(ast_root)
assert astunparse.unparse(ast_root) == """\n\nclass TestIf3():
'Test else.'\n
@staticmethod
def construct(x):
'construct'
if (x > 2):
x -= 2
x -= 2
output_0 = x
else:
output_0 = x
return output_0
"""
def test_elif():
"""
Feature: Test remove return from elif.
Description: Test remove return from elif.
Expectation: Success.
"""
ast_root: ast.Module = ast.parse(inspect.getsource(TestIf4))
folder = RemoveReturnOutOfIf()
folder.transform(ast_root)
assert astunparse.unparse(ast_root) == """\n\nclass TestIf4():
'Test elif.'\n
@staticmethod
def construct(x):
'construct'
if (x > 2):
output_3 = x
else:
x += 2
if (x > 2):
x += 1
if (x > 2):
x *= 2
x += 2
output_1 = x
else:
if (x > 3):
x /= 3
x += 2
output_0 = x
else:
output_0 = x
output_1 = output_0
output_2 = output_1
else:
x *= 2
output_2 = x
output_3 = output_2
return output_3
"""