diff --git a/cmake/package_win.cmake b/cmake/package_win.cmake index 980a50e1ea6..9d3df25e9d3 100644 --- a/cmake/package_win.cmake +++ b/cmake/package_win.cmake @@ -206,6 +206,7 @@ install( ${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/communication ${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/profiler ${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/compression + ${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/rewrite ${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/run_check DESTINATION ${INSTALL_PY_DIR} COMPONENT mindspore diff --git a/mindspore/python/mindspore/rewrite/ast_helpers/ast_modifier.py b/mindspore/python/mindspore/rewrite/ast_helpers/ast_modifier.py index e940e32438c..403b8bf8900 100644 --- a/mindspore/python/mindspore/rewrite/ast_helpers/ast_modifier.py +++ b/mindspore/python/mindspore/rewrite/ast_helpers/ast_modifier.py @@ -342,8 +342,8 @@ class AstModifier(ast.NodeTransformer): dst_ast.value = src_argument.value return if isinstance(dst_ast, ast.Name): - if src_argument.type != ValueType.NamingValue: - raise RuntimeError("src_argument.type should equal to ValueType.NamingValue") + if src_argument.type not in [ValueType.NamingValue, ValueType.StringValue]: + raise RuntimeError("src_argument.type should be ValueType.NamingValue or ValueType.StringValue.") if src_argument.scope: raise RuntimeError("src_argument.scope should be empty") dst_ast.id = src_argument.value diff --git a/mindspore/python/mindspore/rewrite/ast_transformers/__init__.py b/mindspore/python/mindspore/rewrite/ast_transformers/__init__.py index 59ce74f2ab2..ca1b73c1900 100644 --- a/mindspore/python/mindspore/rewrite/ast_transformers/__init__.py +++ b/mindspore/python/mindspore/rewrite/ast_transformers/__init__.py @@ -14,3 +14,4 @@ # ============================================================================ """Transformers for optimizing ast.""" from .flatten_recursive_stmt import FlattenRecursiveStmt +from .remove_return_out_of_if import RemoveReturnOutOfIf diff --git a/mindspore/python/mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py b/mindspore/python/mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py index 87c65cadb20..2cb2c3aa7dc 100644 --- a/mindspore/python/mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +++ b/mindspore/python/mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py @@ -53,6 +53,8 @@ class FlattenRecursiveStmt(ast.NodeTransformer): target_name = "return_value" elif isinstance(node, (ast.BinOp, ast.boolop, ast.UnaryOp)): target_name = type(node.op).__name__ + elif isinstance(node, ast.Tuple): + target_name = type(node).__name__ else: logger.warning("unhandled type of node while generating new target name: %s ", type(node)) target_name = type(node).__name__ @@ -64,21 +66,6 @@ class FlattenRecursiveStmt(ast.NodeTransformer): target_names.append(result) return result - @staticmethod - def _fill_in_original_target_names(target_names, node): - """Fill in original target names before getting unique names.""" - for function_index in range(len(node.body)): - child = node.body[function_index] - if not isinstance(child, ast.Assign): - continue - targets = child.targets - for target in targets: - if not isinstance(target, ast.Name): - raise RuntimeError("currently only support ast.Name targets") - target_name = target.id - if target_name not in target_names: - target_names.append(target_name) - @staticmethod def _create_new_assign_node(node: ast.AST, target_names) -> Tuple[str, ast.AST]: """Create new assign node to be inserted into ast.FunctionDef.""" @@ -122,6 +109,35 @@ class FlattenRecursiveStmt(ast.NodeTransformer): results.append(new_node) return results + def _fill_in_original_target_names(self, target_names, node): + """Fill in original target names before getting unique names.""" + for function_index in range(len(node.body)): + child = node.body[function_index] + if isinstance(child, (ast.Assign, ast.Expr)): + child_value = child.value + else: + child_value = child + if not self._flatten_table.get(type(child_value)): + continue + + if not isinstance(child, ast.Assign): + continue + targets = child.targets + for target in targets: + if not isinstance(target, (ast.Name, ast.Tuple)): + raise RuntimeError("currently only support ast.Name targets") + if isinstance(target, ast.Name): + target_name = target.id + if target_name not in target_names: + target_names.append(target_name) + elif isinstance(target, ast.Tuple): + for elt in target.elts: + if not isinstance(elt, ast.Name): + raise RuntimeError("currently only support ast.Name in ast.Tuple.") + target_name = elt.id + if target_name not in target_names: + target_names.append(target_name) + def visit_FunctionDef(self, node: FunctionDef) -> Any: """Traverse construct node and flatten recursive nodes.""" if node.name != "construct": diff --git a/mindspore/python/mindspore/rewrite/ast_transformers/remove_return_out_of_if.py b/mindspore/python/mindspore/rewrite/ast_transformers/remove_return_out_of_if.py new file mode 100644 index 00000000000..a2d9a366b21 --- /dev/null +++ b/mindspore/python/mindspore/rewrite/ast_transformers/remove_return_out_of_if.py @@ -0,0 +1,225 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Fold if return.""" + +import ast +import copy +from typing import Any, Union +from enum import Enum + + +class ReturnType(Enum): + """ + ValueType represents type of nodes. + + - A `NotReturn` represents the node is not a return node. + - A `IfNotAllReturn` represents the node is an ast.If node and not all branches of it end with return. + - A `Return` represents the node is a return node or an ast.If node of which all branches of it end with return. + """ + NotReturn = 0 + IfNotAllReturn = 1 + Return = 2 + + +class RemoveReturnOutOfIf(ast.NodeTransformer): + """ + Ast optimizer for removing all returns out of if control flow. + + Example one: + def func(x): + if x == 1: + return x - 1 + x += 1 + return x + + will be optimized to + def func(x): + if x == 1: + output_0 = x - 1 + else: + x += 1 + output_0 = x + return output_0 + + Example two: + def func(x): + if x == 1: + x = 0 + elif x == 2: + x = 4 + else: + return x + x += 1 + return x + 1 + + will be optimized to + def func(x): + if x == 1: + x = 0 + x += 1 + output_1 = x + 1 + else: + if x == 2: + x = 4 + x += 1 + output_0 = x + 1 + else: + output_0 = x + output_1 = output_0 + return output_1 + """ + + @staticmethod + def _last_node_is_return(node: Union[ast.Return, ast.If]) -> ReturnType: + """ + Judge whether input node represents a return node. + Return a numeric value according to different cases: + 0: Input node is not an ast.Return or input node is an ast.If of which all branches not end with ast.Return; + 1: Input node is an ast.If and not all branches end with ast.Return; + 2: Input node is an ast.Return or input node is an ast.If of which all branches end with ast.Return. + """ + if not isinstance(node, ast.Return) and not isinstance(node, ast.If): + return ReturnType.NotReturn + if isinstance(node, ast.Return): # last node is ast.Return + return ReturnType.Return + # all branches of ast.If not end with ast.Return + if node.body and RemoveReturnOutOfIf._last_node_is_return(node.body[-1]) == ReturnType.NotReturn \ + and (not node.orelse or RemoveReturnOutOfIf._last_node_is_return(node.orelse[-1]) == + ReturnType.NotReturn): + return ReturnType.NotReturn + # all branches of ast.If end with ast.Return + if node.body and RemoveReturnOutOfIf._last_node_is_return(node.body[-1]) == ReturnType.Return \ + and node.orelse and RemoveReturnOutOfIf._last_node_is_return(node.orelse[-1]) == ReturnType.Return: + return ReturnType.Return + # not all branches of ast.If end with ast.Return + return ReturnType.IfNotAllReturn + + @staticmethod + def _fold_return(father_node: Union[ast.FunctionDef, ast.If], if_node: ast.If, if_index: int, attr: str): + """ + Fold following nodes into if node when not all branches of ast.If end with ast.Return. + + Args: + father_node (Union[ast.FunctionDef, ast.If]): Father node. + if_node (ast.If): A if node. + if_index (int): Index of the if node in body or or-else of father node. + attr (str): Attribute of father node, can be 'body' or 'orelse'. + + Raises: + RuntimeError: Father node has not input attr. + """ + if not hasattr(father_node, attr): + raise RuntimeError('Father node has not input attr', attr) + father_node_attr = getattr(father_node, attr) + if RemoveReturnOutOfIf._last_node_is_return(if_node) == ReturnType.IfNotAllReturn: + # nodes should be copied to all branches which not end with return + if if_node.body and RemoveReturnOutOfIf._last_node_is_return(if_node.body[-1]) != ReturnType.Return: + for index in range(if_index + 1, len(father_node_attr)): + node = copy.deepcopy(father_node_attr[index]) + if_node.body.append(node) + if not if_node.orelse or (if_node.orelse and RemoveReturnOutOfIf._last_node_is_return(if_node.orelse[-1]) + != ReturnType.Return): + for index in range(if_index + 1, len(father_node_attr)): + node = copy.deepcopy(father_node_attr[index]) + if_node.orelse.append(node) + # delete original nodes + remove_num = len(father_node_attr) - if_index - 1 + for _ in range(remove_num): + father_node_attr.pop() + + @staticmethod + def _fold(father_node: Union[ast.FunctionDef, ast.If], attr: str): + """Fold nodes. Iterate into body and orelse of if node.""" + if not hasattr(father_node, attr) or not getattr(father_node, attr): + return + + if isinstance(getattr(father_node, attr)[-1], ast.If): + RemoveReturnOutOfIf._fold(getattr(father_node, attr)[-1], 'body') # if.body + RemoveReturnOutOfIf._fold(getattr(father_node, attr)[-1], 'orelse') # if.orelse + + cur_index = len(getattr(father_node, attr)) - 2 # no following nodes to fold when if node is the last one + while cur_index >= 0: + child = getattr(father_node, attr)[cur_index] + if isinstance(child, ast.If): + RemoveReturnOutOfIf._fold_return(father_node, child, cur_index, attr) + RemoveReturnOutOfIf._fold(child, 'body') # if.body + RemoveReturnOutOfIf._fold(child, 'orelse') # if.orelse + cur_index -= 1 + + @staticmethod + def _get_output_names(output_names: [str]): + """Generate unique output names.""" + name: str = 'output_{}'.format(len(output_names)) + output_names.append(name) + return name + + @staticmethod + def _move_out_return(output_names: [str], father_node: Union[ast.FunctionDef, ast.If], attr: str): + """ + Move all return node out of if nodes. + Replace all original return nodes in ast.If with ast.Assign nodes which represent 'output = return value'. + And add new ast.Return node to the end of father node. + + Args: + output_names ([str]): All unique output names. + father_node (Union[ast.FunctionDef, ast.If]): Father node. + attr (str): Attribute of father nodes, can be 'body' or 'orelse'. + + Raises: + RuntimeError: After iterative processing body and orelse of if nodes not all end with ast.Return. + """ + if not hasattr(father_node, attr) or not getattr(father_node, attr): + return + + last_node = getattr(father_node, attr)[-1] + if isinstance(last_node, ast.If) and RemoveReturnOutOfIf._last_node_is_return(last_node) == ReturnType.Return: + # the body or orelse of last if node should be ast.Return or ast.If + if isinstance(last_node.body[-1], ast.If): + RemoveReturnOutOfIf._move_out_return(output_names, last_node, 'body') + if isinstance(last_node.orelse[-1], ast.If): + RemoveReturnOutOfIf._move_out_return(output_names, last_node, 'orelse') + + # assert body and or-else all end with return + if not isinstance(last_node.body[-1], ast.Return) or not isinstance(last_node.orelse[-1], ast.Return): + raise RuntimeError("Body and orelse of if nodes not all end with ast.Return.") + output_name = RemoveReturnOutOfIf._get_output_names(output_names) + # replace body return + body_new_last_node = ast.Assign( + targets=[ast.Name(id=output_name, ctx=ast.Store())], value=last_node.body[-1].value) + last_node.body.pop() + last_node.body.append(body_new_last_node) + # replace else return + else_new_last_node = ast.Assign( + targets=[ast.Name(id=output_name, ctx=ast.Store())], value=last_node.orelse[-1].value) + last_node.orelse.pop() + last_node.orelse.append(else_new_last_node) + # add new return node + new_return_node = ast.Return(value=ast.Name(id=output_name, cts=ast.Store())) + getattr(father_node, attr).append(new_return_node) + + def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: + """Iterate construct node and fold following nodes into if node when condition is met.""" + if node.name != "construct": + return node + RemoveReturnOutOfIf._fold(node, 'body') + output_names = [] + RemoveReturnOutOfIf._move_out_return(output_names, node, 'body') + return node + + def transform(self, ast_root): + """Transform.""" + ast_root = self.visit(ast_root) + ast_root = ast.fix_missing_locations(ast_root) + return ast_root diff --git a/mindspore/python/mindspore/rewrite/node.py b/mindspore/python/mindspore/rewrite/node.py index 204eef81c55..de0c5fad3f3 100644 --- a/mindspore/python/mindspore/rewrite/node.py +++ b/mindspore/python/mindspore/rewrite/node.py @@ -526,13 +526,21 @@ class Node: raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast)) # update targets targets_ast = assign_ast.targets - if len(self._targets) != len(targets_ast): + if isinstance(targets_ast[0], ast.Tuple) and len(self._targets) != len(targets_ast[0].elts): + raise RuntimeError("self._targets should have the same length as targets_ast's elts") + if not isinstance(targets_ast[0], ast.Tuple) and len(self._targets) != len(targets_ast): raise RuntimeError("self._targets should have targets_ast same length") for i in range(0, len(self._targets)): target = self._targets[i] - target_ast = targets_ast[i] - if not isinstance(target_ast, ast.Name): - raise TypeError("target_ast should be ast.Name, got: ", type(target_ast)) + target_ast = targets_ast[0] + if isinstance(target_ast, ast.Name): + target_ast.id = target.value + elif isinstance(target_ast, ast.Tuple): + if not isinstance(target_ast.elts[i], ast.Name): + raise TypeError("target should be ast.Name, got:", type(target_ast.elts[i])) + target_ast.elts[i].id = target.value + else: + raise TypeError("target_ast should be ast.Name or ast.Tuple, got: ", type(target_ast)) target_ast.id = target.value ast.fix_missing_locations(assign_ast) diff --git a/mindspore/python/mindspore/rewrite/parsers/assign_parser.py b/mindspore/python/mindspore/rewrite/parsers/assign_parser.py index 4a2d2011b0f..d0cd04682b7 100644 --- a/mindspore/python/mindspore/rewrite/parsers/assign_parser.py +++ b/mindspore/python/mindspore/rewrite/parsers/assign_parser.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================ """Parse ast.Assign in construct function to node of SymbolTree.""" +from typing import Union import ast import astunparse @@ -24,7 +25,7 @@ from ..symbol_tree import SymbolTree from ..node import Node, TreeNode from ..parser import Parser from ..parser_register import reg_parser -from ..api.scoped_value import ScopedValue +from ..api.scoped_value import ScopedValue, ValueType from ..symbol_tree_builder import SymbolTreeBuilder from ..ast_helpers import AstReplacer, AstModifier @@ -59,9 +60,12 @@ class AssignParser(Parser): tuple_elts = node.elts tuple_values = [] for tuple_elt in tuple_elts: - if not isinstance(tuple_elt, ast.Constant): - raise RuntimeError("Only support ast.Constant as elts of ast.Tuple.") - tuple_values.append(tuple_elt.value) + if not isinstance(tuple_elt, (ast.Constant, ast.Name)): + raise RuntimeError("Only support ast.Constant or ast.Name as elts of ast.Tuple.") + if isinstance(tuple_elt, ast.Constant): + tuple_values.append(tuple_elt.value) + elif isinstance(tuple_elt, ast.Name): + tuple_values.append(tuple_elt.id) return ScopedValue.create_variable_value(tuple(tuple_values)) @staticmethod @@ -111,7 +115,9 @@ class AssignParser(Parser): return func.id if isinstance(func, ast.Attribute): return func.attr - raise RuntimeError("FuncValue is should be Name or a Attribute:", astunparse.unparse(func)) + if isinstance(func, ast.Call): + return AssignParser._get_func_name(func) + raise RuntimeError("FuncValue is should be Name or a Attribute or a Call:", astunparse.unparse(func)) @staticmethod def _get_func_scope(ast_node: ast.Call) -> str: @@ -136,7 +142,9 @@ class AssignParser(Parser): if not isinstance(value, ast.Name): raise RuntimeError("FuncValue is should be Name:", ast.dump(func)) return value.id - raise RuntimeError("FuncValue is should be Name or a Attribute:", ast.dump(func)) + if isinstance(func, ast.Call): + return AssignParser._get_func_scope(func) + raise RuntimeError("FuncValue should be Name or a Attribute or a Call:", ast.dump(func)) @staticmethod def _get_symbol_object(symbol_name, origin_net): @@ -206,6 +214,19 @@ class AssignParser(Parser): return type(value), value return type(None), None + @staticmethod + def _get_targets(all_targets: ScopedValue) -> [Union[ScopedValue, str]]: + """Get targets from tuple or single value.""" + targets: [Union[ScopedValue, str]] = [] + if all_targets.type == ValueType.TupleValue: + for single_target in all_targets.value: + if not isinstance(single_target, ScopedValue) and not isinstance(single_target.value, str): + raise RuntimeError("Only support str target in tuple.") + targets.append(single_target) + else: + targets.append(all_targets) + return targets + def _update_field_in_init(self, func_scope, func_name, stree: SymbolTree, sub_tree: SymbolTree): """ When node is an invoking to sub-network, update value of ast.Assign of corresponding field in `__init__` method. @@ -269,7 +290,7 @@ class AssignParser(Parser): Raises: RuntimeError: If operator instance invoked by assign is undefined. """ - target = AssignParser._create_scopedvalue(father_ast_node.targets[0]) + targets = AssignParser._get_targets(AssignParser._create_scopedvalue(father_ast_node.targets[0])) func_name = AssignParser._get_func_name(ast_node) if func_name is None or func_name == "": raise RuntimeError("function name not exist") @@ -283,7 +304,7 @@ class AssignParser(Parser): raise RuntimeError("Operator instance undefined: '", ast.unparse(ast_node.func), "' of '", ast.unparse(ast_node), "'") if isinstance(op, Primitive): - return Node.create_call_buildin_op(op, father_ast_node, [target], func, call_args, call_kwargs, func_name) + return Node.create_call_buildin_op(op, father_ast_node, targets, func, call_args, call_kwargs, func_name) if isinstance(op, Cell): is_sub_tree = self._is_subtree_cell(op) if is_sub_tree: @@ -292,9 +313,9 @@ class AssignParser(Parser): self._update_field_in_init(func_scope, func_name, stree, new_stree) replacer = AstReplacer(new_stree.get_class_ast()) replacer.replace_all(new_stree.get_ori_cls_name(), new_stree.get_opt_cls_name()) - return TreeNode(new_stree, father_ast_node, [target], func, call_args, call_kwargs, func_name, + return TreeNode(new_stree, father_ast_node, targets, func, call_args, call_kwargs, func_name, new_stree.get_origin_network()) - return Node.create_call_buildin_op(op, father_ast_node, [target], func, call_args, call_kwargs, func_name) + return Node.create_call_buildin_op(op, father_ast_node, targets, func, call_args, call_kwargs, func_name) raise RuntimeError("Only support Cell operator or Primitive operator, got ", type(op).__name__) def process(self, stree: SymbolTree, node: ast.Assign): @@ -332,9 +353,9 @@ class AssignParser(Parser): node_name = "constant_assign" else: node_name = "attribute_assign" - target = AssignParser._create_scopedvalue(node.targets[0]) + targets = AssignParser._get_targets(AssignParser._create_scopedvalue(node.targets[0])) call_args = [AssignParser._create_scopedvalue(value)] - node_ = Node.create_call_pass_through_method(node, [target], call_args, {}, node_name) + node_ = Node.create_call_pass_through_method(node, targets, call_args, {}, node_name) stree.append_origin_field(node_) elif isinstance(value, (ast.List, ast.Tuple, ast.Dict)): # add these as callmethod node if necessary diff --git a/mindspore/python/mindspore/rewrite/symbol_tree_builder.py b/mindspore/python/mindspore/rewrite/symbol_tree_builder.py index 6b3b3141559..92c94e6ae8b 100644 --- a/mindspore/python/mindspore/rewrite/symbol_tree_builder.py +++ b/mindspore/python/mindspore/rewrite/symbol_tree_builder.py @@ -23,7 +23,7 @@ from .symbol_tree import SymbolTree from .node import TreeNode from .parser_register import ParserRegister from .parser import Parser -from .ast_transformers import FlattenRecursiveStmt +from .ast_transformers import FlattenRecursiveStmt, RemoveReturnOutOfIf from .ast_helpers import AstModifier from .ast_helpers import AstFinder @@ -55,7 +55,7 @@ class SymbolTreeBuilder: Returns: An instance of ast been optimized. """ - transform_list = [FlattenRecursiveStmt()] + transform_list = [FlattenRecursiveStmt(), RemoveReturnOutOfIf()] for transformer in transform_list: ast_root = transformer.transform(ast_root) return ast_root diff --git a/tests/ut/python/rewrite/test_multiple_targets.py b/tests/ut/python/rewrite/test_multiple_targets.py new file mode 100644 index 00000000000..daaf250a537 --- /dev/null +++ b/tests/ut/python/rewrite/test_multiple_targets.py @@ -0,0 +1,63 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from mindspore.nn import Cell, Conv2d +from mindspore.rewrite import SymbolTree +from mindspore.ops import operations as P + + +class SubNet(Cell): + """Sample cell which returns multiple features.""" + def __init__(self): + """Init.""" + super().__init__() + self.conv = Conv2d(1, 10, 3) + + def construct(self, x): + """Construct.""" + c1 = self.conv(x) + c2 = self.conv(c1) + c3 = self.conv(c2) + return c1, c2, c3 + + +class NetMultiTargets(Cell): + """Test cls for multiple targets.""" + def __init__(self): + """Init.""" + super(NetMultiTargets, self).__init__() + self.conv1 = SubNet() + self.add = P.Add() + + def construct(self, x): + """Construct.""" + c1, c2, c3 = self.conv1(x) + x = self.add(c1, c2) + x = self.add(x, c3) + return x + + +def test_multi_targets(): + """ + Feature: Test multi-targets. + Description: Test multi-targets. + Expectation: Success. + """ + test_cls = NetMultiTargets() + stree = SymbolTree.create(test_cls) + node = stree.nodes()[2] + targets = node.get_targets() + assert targets[0].value == 'c1' + assert targets[1].value == 'c2' + assert targets[2].value == 'c3' diff --git a/tests/ut/python/rewrite/test_remove_return_if.py b/tests/ut/python/rewrite/test_remove_return_if.py new file mode 100644 index 00000000000..eedd4eab7f7 --- /dev/null +++ b/tests/ut/python/rewrite/test_remove_return_if.py @@ -0,0 +1,201 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import inspect +import ast +import astunparse + +from mindspore.rewrite.ast_transformers import RemoveReturnOutOfIf + + +class TestIf: + """Simple test.""" + @staticmethod + def construct(x): + """construct""" + if x > 2: + return x - 2 + return x + + +class TestIf2: + """Test multiple if and test if in if.""" + @staticmethod + def construct(x): + """construct""" + if x > 2: + return x + x += 2 + if x > 2: + if x > 2: + return x + x += 2 + return x + x *= 2 + return x + + +class TestIf3: + """Test else.""" + @staticmethod + def construct(x): + """construct""" + if x > 2: + x -= 2 + else: + return x + x -= 2 + return x + + +class TestIf4: + """Test elif.""" + @staticmethod + def construct(x): + """construct""" + if x > 2: + return x + x += 2 + if x > 2: + x += 1 + if x > 2: + x *= 2 + elif x > 3: + x /= 3 + else: + return x + x += 2 + return x + x *= 2 + return x + + +def test_simple_if(): + """ + Feature: Test remove return from simple if. + Description: Test remove return from simple if. + Expectation: Success. + """ + ast_root: ast.Module = ast.parse(inspect.getsource(TestIf)) + folder = RemoveReturnOutOfIf() + folder.transform(ast_root) + assert astunparse.unparse(ast_root) == """\n\nclass TestIf(): + 'Simple test.'\n + @staticmethod + def construct(x): + 'construct' + if (x > 2): + output_0 = (x - 2) + else: + output_0 = x + return output_0 +""" + + +def test_multiple_if(): + """ + Feature: Test remove return from multiple if. + Description: Test remove return from multiple if. + Expectation: Success. + """ + ast_root: ast.Module = ast.parse(inspect.getsource(TestIf2)) + folder = RemoveReturnOutOfIf() + folder.transform(ast_root) + assert astunparse.unparse(ast_root) == """\n\nclass TestIf2(): + 'Test multiple if and test if in if.'\n + @staticmethod + def construct(x): + 'construct' + if (x > 2): + output_2 = x + else: + x += 2 + if (x > 2): + if (x > 2): + output_0 = x + else: + x += 2 + output_0 = x + output_1 = output_0 + else: + x *= 2 + output_1 = x + output_2 = output_1 + return output_2 +""" + + +def test_else(): + """ + Feature: Test remove return in else of if node. + Description: Test remove return in else of if node. + Expectation: Success. + """ + ast_root: ast.Module = ast.parse(inspect.getsource(TestIf3)) + folder = RemoveReturnOutOfIf() + folder.transform(ast_root) + assert astunparse.unparse(ast_root) == """\n\nclass TestIf3(): + 'Test else.'\n + @staticmethod + def construct(x): + 'construct' + if (x > 2): + x -= 2 + x -= 2 + output_0 = x + else: + output_0 = x + return output_0 +""" + + +def test_elif(): + """ + Feature: Test remove return from elif. + Description: Test remove return from elif. + Expectation: Success. + """ + ast_root: ast.Module = ast.parse(inspect.getsource(TestIf4)) + folder = RemoveReturnOutOfIf() + folder.transform(ast_root) + assert astunparse.unparse(ast_root) == """\n\nclass TestIf4(): + 'Test elif.'\n + @staticmethod + def construct(x): + 'construct' + if (x > 2): + output_3 = x + else: + x += 2 + if (x > 2): + x += 1 + if (x > 2): + x *= 2 + x += 2 + output_1 = x + else: + if (x > 3): + x /= 3 + x += 2 + output_0 = x + else: + output_0 = x + output_1 = output_0 + output_2 = output_1 + else: + x *= 2 + output_2 = x + output_3 = output_2 + return output_3 +""" \ No newline at end of file