forked from mindspore-Ecosystem/mindspore
[MS][Rewrite]Move return out of if and support multiple targets
[MS][Rewrite]Move return out of if and support multiple targets [MS][Rewrite]Move return out of if and support multiple targets [MS][Rewrite]Move return out of if and support multiple targets [MS][Rewrite]Move return out of if and support multiple targets [MS][Rewrite]Move return out of if and support multiple targets [MS][Rewrite]Move return out of if and support multiple targets [MS][Rewrite]Move return out of if and support multiple targets [MS][Rewrite]Move return out of if and support multiple targets
This commit is contained in:
parent
3906a598ca
commit
317f59a0ae
|
@ -206,6 +206,7 @@ install(
|
||||||
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/communication
|
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/communication
|
||||||
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/profiler
|
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/profiler
|
||||||
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/compression
|
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/compression
|
||||||
|
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/rewrite
|
||||||
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/run_check
|
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/run_check
|
||||||
DESTINATION ${INSTALL_PY_DIR}
|
DESTINATION ${INSTALL_PY_DIR}
|
||||||
COMPONENT mindspore
|
COMPONENT mindspore
|
||||||
|
|
|
@ -342,8 +342,8 @@ class AstModifier(ast.NodeTransformer):
|
||||||
dst_ast.value = src_argument.value
|
dst_ast.value = src_argument.value
|
||||||
return
|
return
|
||||||
if isinstance(dst_ast, ast.Name):
|
if isinstance(dst_ast, ast.Name):
|
||||||
if src_argument.type != ValueType.NamingValue:
|
if src_argument.type not in [ValueType.NamingValue, ValueType.StringValue]:
|
||||||
raise RuntimeError("src_argument.type should equal to ValueType.NamingValue")
|
raise RuntimeError("src_argument.type should be ValueType.NamingValue or ValueType.StringValue.")
|
||||||
if src_argument.scope:
|
if src_argument.scope:
|
||||||
raise RuntimeError("src_argument.scope should be empty")
|
raise RuntimeError("src_argument.scope should be empty")
|
||||||
dst_ast.id = src_argument.value
|
dst_ast.id = src_argument.value
|
||||||
|
|
|
@ -14,3 +14,4 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""Transformers for optimizing ast."""
|
"""Transformers for optimizing ast."""
|
||||||
from .flatten_recursive_stmt import FlattenRecursiveStmt
|
from .flatten_recursive_stmt import FlattenRecursiveStmt
|
||||||
|
from .remove_return_out_of_if import RemoveReturnOutOfIf
|
||||||
|
|
|
@ -53,6 +53,8 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
||||||
target_name = "return_value"
|
target_name = "return_value"
|
||||||
elif isinstance(node, (ast.BinOp, ast.boolop, ast.UnaryOp)):
|
elif isinstance(node, (ast.BinOp, ast.boolop, ast.UnaryOp)):
|
||||||
target_name = type(node.op).__name__
|
target_name = type(node.op).__name__
|
||||||
|
elif isinstance(node, ast.Tuple):
|
||||||
|
target_name = type(node).__name__
|
||||||
else:
|
else:
|
||||||
logger.warning("unhandled type of node while generating new target name: %s ", type(node))
|
logger.warning("unhandled type of node while generating new target name: %s ", type(node))
|
||||||
target_name = type(node).__name__
|
target_name = type(node).__name__
|
||||||
|
@ -64,21 +66,6 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
||||||
target_names.append(result)
|
target_names.append(result)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _fill_in_original_target_names(target_names, node):
|
|
||||||
"""Fill in original target names before getting unique names."""
|
|
||||||
for function_index in range(len(node.body)):
|
|
||||||
child = node.body[function_index]
|
|
||||||
if not isinstance(child, ast.Assign):
|
|
||||||
continue
|
|
||||||
targets = child.targets
|
|
||||||
for target in targets:
|
|
||||||
if not isinstance(target, ast.Name):
|
|
||||||
raise RuntimeError("currently only support ast.Name targets")
|
|
||||||
target_name = target.id
|
|
||||||
if target_name not in target_names:
|
|
||||||
target_names.append(target_name)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _create_new_assign_node(node: ast.AST, target_names) -> Tuple[str, ast.AST]:
|
def _create_new_assign_node(node: ast.AST, target_names) -> Tuple[str, ast.AST]:
|
||||||
"""Create new assign node to be inserted into ast.FunctionDef."""
|
"""Create new assign node to be inserted into ast.FunctionDef."""
|
||||||
|
@ -122,6 +109,35 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
||||||
results.append(new_node)
|
results.append(new_node)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
def _fill_in_original_target_names(self, target_names, node):
|
||||||
|
"""Fill in original target names before getting unique names."""
|
||||||
|
for function_index in range(len(node.body)):
|
||||||
|
child = node.body[function_index]
|
||||||
|
if isinstance(child, (ast.Assign, ast.Expr)):
|
||||||
|
child_value = child.value
|
||||||
|
else:
|
||||||
|
child_value = child
|
||||||
|
if not self._flatten_table.get(type(child_value)):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not isinstance(child, ast.Assign):
|
||||||
|
continue
|
||||||
|
targets = child.targets
|
||||||
|
for target in targets:
|
||||||
|
if not isinstance(target, (ast.Name, ast.Tuple)):
|
||||||
|
raise RuntimeError("currently only support ast.Name targets")
|
||||||
|
if isinstance(target, ast.Name):
|
||||||
|
target_name = target.id
|
||||||
|
if target_name not in target_names:
|
||||||
|
target_names.append(target_name)
|
||||||
|
elif isinstance(target, ast.Tuple):
|
||||||
|
for elt in target.elts:
|
||||||
|
if not isinstance(elt, ast.Name):
|
||||||
|
raise RuntimeError("currently only support ast.Name in ast.Tuple.")
|
||||||
|
target_name = elt.id
|
||||||
|
if target_name not in target_names:
|
||||||
|
target_names.append(target_name)
|
||||||
|
|
||||||
def visit_FunctionDef(self, node: FunctionDef) -> Any:
|
def visit_FunctionDef(self, node: FunctionDef) -> Any:
|
||||||
"""Traverse construct node and flatten recursive nodes."""
|
"""Traverse construct node and flatten recursive nodes."""
|
||||||
if node.name != "construct":
|
if node.name != "construct":
|
||||||
|
|
|
@ -0,0 +1,225 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Fold if return."""
|
||||||
|
|
||||||
|
import ast
|
||||||
|
import copy
|
||||||
|
from typing import Any, Union
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class ReturnType(Enum):
|
||||||
|
"""
|
||||||
|
ValueType represents type of nodes.
|
||||||
|
|
||||||
|
- A `NotReturn` represents the node is not a return node.
|
||||||
|
- A `IfNotAllReturn` represents the node is an ast.If node and not all branches of it end with return.
|
||||||
|
- A `Return` represents the node is a return node or an ast.If node of which all branches of it end with return.
|
||||||
|
"""
|
||||||
|
NotReturn = 0
|
||||||
|
IfNotAllReturn = 1
|
||||||
|
Return = 2
|
||||||
|
|
||||||
|
|
||||||
|
class RemoveReturnOutOfIf(ast.NodeTransformer):
|
||||||
|
"""
|
||||||
|
Ast optimizer for removing all returns out of if control flow.
|
||||||
|
|
||||||
|
Example one:
|
||||||
|
def func(x):
|
||||||
|
if x == 1:
|
||||||
|
return x - 1
|
||||||
|
x += 1
|
||||||
|
return x
|
||||||
|
|
||||||
|
will be optimized to
|
||||||
|
def func(x):
|
||||||
|
if x == 1:
|
||||||
|
output_0 = x - 1
|
||||||
|
else:
|
||||||
|
x += 1
|
||||||
|
output_0 = x
|
||||||
|
return output_0
|
||||||
|
|
||||||
|
Example two:
|
||||||
|
def func(x):
|
||||||
|
if x == 1:
|
||||||
|
x = 0
|
||||||
|
elif x == 2:
|
||||||
|
x = 4
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
x += 1
|
||||||
|
return x + 1
|
||||||
|
|
||||||
|
will be optimized to
|
||||||
|
def func(x):
|
||||||
|
if x == 1:
|
||||||
|
x = 0
|
||||||
|
x += 1
|
||||||
|
output_1 = x + 1
|
||||||
|
else:
|
||||||
|
if x == 2:
|
||||||
|
x = 4
|
||||||
|
x += 1
|
||||||
|
output_0 = x + 1
|
||||||
|
else:
|
||||||
|
output_0 = x
|
||||||
|
output_1 = output_0
|
||||||
|
return output_1
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _last_node_is_return(node: Union[ast.Return, ast.If]) -> ReturnType:
|
||||||
|
"""
|
||||||
|
Judge whether input node represents a return node.
|
||||||
|
Return a numeric value according to different cases:
|
||||||
|
0: Input node is not an ast.Return or input node is an ast.If of which all branches not end with ast.Return;
|
||||||
|
1: Input node is an ast.If and not all branches end with ast.Return;
|
||||||
|
2: Input node is an ast.Return or input node is an ast.If of which all branches end with ast.Return.
|
||||||
|
"""
|
||||||
|
if not isinstance(node, ast.Return) and not isinstance(node, ast.If):
|
||||||
|
return ReturnType.NotReturn
|
||||||
|
if isinstance(node, ast.Return): # last node is ast.Return
|
||||||
|
return ReturnType.Return
|
||||||
|
# all branches of ast.If not end with ast.Return
|
||||||
|
if node.body and RemoveReturnOutOfIf._last_node_is_return(node.body[-1]) == ReturnType.NotReturn \
|
||||||
|
and (not node.orelse or RemoveReturnOutOfIf._last_node_is_return(node.orelse[-1]) ==
|
||||||
|
ReturnType.NotReturn):
|
||||||
|
return ReturnType.NotReturn
|
||||||
|
# all branches of ast.If end with ast.Return
|
||||||
|
if node.body and RemoveReturnOutOfIf._last_node_is_return(node.body[-1]) == ReturnType.Return \
|
||||||
|
and node.orelse and RemoveReturnOutOfIf._last_node_is_return(node.orelse[-1]) == ReturnType.Return:
|
||||||
|
return ReturnType.Return
|
||||||
|
# not all branches of ast.If end with ast.Return
|
||||||
|
return ReturnType.IfNotAllReturn
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _fold_return(father_node: Union[ast.FunctionDef, ast.If], if_node: ast.If, if_index: int, attr: str):
|
||||||
|
"""
|
||||||
|
Fold following nodes into if node when not all branches of ast.If end with ast.Return.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
father_node (Union[ast.FunctionDef, ast.If]): Father node.
|
||||||
|
if_node (ast.If): A if node.
|
||||||
|
if_index (int): Index of the if node in body or or-else of father node.
|
||||||
|
attr (str): Attribute of father node, can be 'body' or 'orelse'.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: Father node has not input attr.
|
||||||
|
"""
|
||||||
|
if not hasattr(father_node, attr):
|
||||||
|
raise RuntimeError('Father node has not input attr', attr)
|
||||||
|
father_node_attr = getattr(father_node, attr)
|
||||||
|
if RemoveReturnOutOfIf._last_node_is_return(if_node) == ReturnType.IfNotAllReturn:
|
||||||
|
# nodes should be copied to all branches which not end with return
|
||||||
|
if if_node.body and RemoveReturnOutOfIf._last_node_is_return(if_node.body[-1]) != ReturnType.Return:
|
||||||
|
for index in range(if_index + 1, len(father_node_attr)):
|
||||||
|
node = copy.deepcopy(father_node_attr[index])
|
||||||
|
if_node.body.append(node)
|
||||||
|
if not if_node.orelse or (if_node.orelse and RemoveReturnOutOfIf._last_node_is_return(if_node.orelse[-1])
|
||||||
|
!= ReturnType.Return):
|
||||||
|
for index in range(if_index + 1, len(father_node_attr)):
|
||||||
|
node = copy.deepcopy(father_node_attr[index])
|
||||||
|
if_node.orelse.append(node)
|
||||||
|
# delete original nodes
|
||||||
|
remove_num = len(father_node_attr) - if_index - 1
|
||||||
|
for _ in range(remove_num):
|
||||||
|
father_node_attr.pop()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _fold(father_node: Union[ast.FunctionDef, ast.If], attr: str):
|
||||||
|
"""Fold nodes. Iterate into body and orelse of if node."""
|
||||||
|
if not hasattr(father_node, attr) or not getattr(father_node, attr):
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(getattr(father_node, attr)[-1], ast.If):
|
||||||
|
RemoveReturnOutOfIf._fold(getattr(father_node, attr)[-1], 'body') # if.body
|
||||||
|
RemoveReturnOutOfIf._fold(getattr(father_node, attr)[-1], 'orelse') # if.orelse
|
||||||
|
|
||||||
|
cur_index = len(getattr(father_node, attr)) - 2 # no following nodes to fold when if node is the last one
|
||||||
|
while cur_index >= 0:
|
||||||
|
child = getattr(father_node, attr)[cur_index]
|
||||||
|
if isinstance(child, ast.If):
|
||||||
|
RemoveReturnOutOfIf._fold_return(father_node, child, cur_index, attr)
|
||||||
|
RemoveReturnOutOfIf._fold(child, 'body') # if.body
|
||||||
|
RemoveReturnOutOfIf._fold(child, 'orelse') # if.orelse
|
||||||
|
cur_index -= 1
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_output_names(output_names: [str]):
|
||||||
|
"""Generate unique output names."""
|
||||||
|
name: str = 'output_{}'.format(len(output_names))
|
||||||
|
output_names.append(name)
|
||||||
|
return name
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _move_out_return(output_names: [str], father_node: Union[ast.FunctionDef, ast.If], attr: str):
|
||||||
|
"""
|
||||||
|
Move all return node out of if nodes.
|
||||||
|
Replace all original return nodes in ast.If with ast.Assign nodes which represent 'output = return value'.
|
||||||
|
And add new ast.Return node to the end of father node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_names ([str]): All unique output names.
|
||||||
|
father_node (Union[ast.FunctionDef, ast.If]): Father node.
|
||||||
|
attr (str): Attribute of father nodes, can be 'body' or 'orelse'.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: After iterative processing body and orelse of if nodes not all end with ast.Return.
|
||||||
|
"""
|
||||||
|
if not hasattr(father_node, attr) or not getattr(father_node, attr):
|
||||||
|
return
|
||||||
|
|
||||||
|
last_node = getattr(father_node, attr)[-1]
|
||||||
|
if isinstance(last_node, ast.If) and RemoveReturnOutOfIf._last_node_is_return(last_node) == ReturnType.Return:
|
||||||
|
# the body or orelse of last if node should be ast.Return or ast.If
|
||||||
|
if isinstance(last_node.body[-1], ast.If):
|
||||||
|
RemoveReturnOutOfIf._move_out_return(output_names, last_node, 'body')
|
||||||
|
if isinstance(last_node.orelse[-1], ast.If):
|
||||||
|
RemoveReturnOutOfIf._move_out_return(output_names, last_node, 'orelse')
|
||||||
|
|
||||||
|
# assert body and or-else all end with return
|
||||||
|
if not isinstance(last_node.body[-1], ast.Return) or not isinstance(last_node.orelse[-1], ast.Return):
|
||||||
|
raise RuntimeError("Body and orelse of if nodes not all end with ast.Return.")
|
||||||
|
output_name = RemoveReturnOutOfIf._get_output_names(output_names)
|
||||||
|
# replace body return
|
||||||
|
body_new_last_node = ast.Assign(
|
||||||
|
targets=[ast.Name(id=output_name, ctx=ast.Store())], value=last_node.body[-1].value)
|
||||||
|
last_node.body.pop()
|
||||||
|
last_node.body.append(body_new_last_node)
|
||||||
|
# replace else return
|
||||||
|
else_new_last_node = ast.Assign(
|
||||||
|
targets=[ast.Name(id=output_name, ctx=ast.Store())], value=last_node.orelse[-1].value)
|
||||||
|
last_node.orelse.pop()
|
||||||
|
last_node.orelse.append(else_new_last_node)
|
||||||
|
# add new return node
|
||||||
|
new_return_node = ast.Return(value=ast.Name(id=output_name, cts=ast.Store()))
|
||||||
|
getattr(father_node, attr).append(new_return_node)
|
||||||
|
|
||||||
|
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
|
||||||
|
"""Iterate construct node and fold following nodes into if node when condition is met."""
|
||||||
|
if node.name != "construct":
|
||||||
|
return node
|
||||||
|
RemoveReturnOutOfIf._fold(node, 'body')
|
||||||
|
output_names = []
|
||||||
|
RemoveReturnOutOfIf._move_out_return(output_names, node, 'body')
|
||||||
|
return node
|
||||||
|
|
||||||
|
def transform(self, ast_root):
|
||||||
|
"""Transform."""
|
||||||
|
ast_root = self.visit(ast_root)
|
||||||
|
ast_root = ast.fix_missing_locations(ast_root)
|
||||||
|
return ast_root
|
|
@ -526,13 +526,21 @@ class Node:
|
||||||
raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast))
|
raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast))
|
||||||
# update targets
|
# update targets
|
||||||
targets_ast = assign_ast.targets
|
targets_ast = assign_ast.targets
|
||||||
if len(self._targets) != len(targets_ast):
|
if isinstance(targets_ast[0], ast.Tuple) and len(self._targets) != len(targets_ast[0].elts):
|
||||||
|
raise RuntimeError("self._targets should have the same length as targets_ast's elts")
|
||||||
|
if not isinstance(targets_ast[0], ast.Tuple) and len(self._targets) != len(targets_ast):
|
||||||
raise RuntimeError("self._targets should have targets_ast same length")
|
raise RuntimeError("self._targets should have targets_ast same length")
|
||||||
for i in range(0, len(self._targets)):
|
for i in range(0, len(self._targets)):
|
||||||
target = self._targets[i]
|
target = self._targets[i]
|
||||||
target_ast = targets_ast[i]
|
target_ast = targets_ast[0]
|
||||||
if not isinstance(target_ast, ast.Name):
|
if isinstance(target_ast, ast.Name):
|
||||||
raise TypeError("target_ast should be ast.Name, got: ", type(target_ast))
|
target_ast.id = target.value
|
||||||
|
elif isinstance(target_ast, ast.Tuple):
|
||||||
|
if not isinstance(target_ast.elts[i], ast.Name):
|
||||||
|
raise TypeError("target should be ast.Name, got:", type(target_ast.elts[i]))
|
||||||
|
target_ast.elts[i].id = target.value
|
||||||
|
else:
|
||||||
|
raise TypeError("target_ast should be ast.Name or ast.Tuple, got: ", type(target_ast))
|
||||||
target_ast.id = target.value
|
target_ast.id = target.value
|
||||||
ast.fix_missing_locations(assign_ast)
|
ast.fix_missing_locations(assign_ast)
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""Parse ast.Assign in construct function to node of SymbolTree."""
|
"""Parse ast.Assign in construct function to node of SymbolTree."""
|
||||||
|
from typing import Union
|
||||||
import ast
|
import ast
|
||||||
import astunparse
|
import astunparse
|
||||||
|
|
||||||
|
@ -24,7 +25,7 @@ from ..symbol_tree import SymbolTree
|
||||||
from ..node import Node, TreeNode
|
from ..node import Node, TreeNode
|
||||||
from ..parser import Parser
|
from ..parser import Parser
|
||||||
from ..parser_register import reg_parser
|
from ..parser_register import reg_parser
|
||||||
from ..api.scoped_value import ScopedValue
|
from ..api.scoped_value import ScopedValue, ValueType
|
||||||
from ..symbol_tree_builder import SymbolTreeBuilder
|
from ..symbol_tree_builder import SymbolTreeBuilder
|
||||||
from ..ast_helpers import AstReplacer, AstModifier
|
from ..ast_helpers import AstReplacer, AstModifier
|
||||||
|
|
||||||
|
@ -59,9 +60,12 @@ class AssignParser(Parser):
|
||||||
tuple_elts = node.elts
|
tuple_elts = node.elts
|
||||||
tuple_values = []
|
tuple_values = []
|
||||||
for tuple_elt in tuple_elts:
|
for tuple_elt in tuple_elts:
|
||||||
if not isinstance(tuple_elt, ast.Constant):
|
if not isinstance(tuple_elt, (ast.Constant, ast.Name)):
|
||||||
raise RuntimeError("Only support ast.Constant as elts of ast.Tuple.")
|
raise RuntimeError("Only support ast.Constant or ast.Name as elts of ast.Tuple.")
|
||||||
tuple_values.append(tuple_elt.value)
|
if isinstance(tuple_elt, ast.Constant):
|
||||||
|
tuple_values.append(tuple_elt.value)
|
||||||
|
elif isinstance(tuple_elt, ast.Name):
|
||||||
|
tuple_values.append(tuple_elt.id)
|
||||||
return ScopedValue.create_variable_value(tuple(tuple_values))
|
return ScopedValue.create_variable_value(tuple(tuple_values))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -111,7 +115,9 @@ class AssignParser(Parser):
|
||||||
return func.id
|
return func.id
|
||||||
if isinstance(func, ast.Attribute):
|
if isinstance(func, ast.Attribute):
|
||||||
return func.attr
|
return func.attr
|
||||||
raise RuntimeError("FuncValue is should be Name or a Attribute:", astunparse.unparse(func))
|
if isinstance(func, ast.Call):
|
||||||
|
return AssignParser._get_func_name(func)
|
||||||
|
raise RuntimeError("FuncValue is should be Name or a Attribute or a Call:", astunparse.unparse(func))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_func_scope(ast_node: ast.Call) -> str:
|
def _get_func_scope(ast_node: ast.Call) -> str:
|
||||||
|
@ -136,7 +142,9 @@ class AssignParser(Parser):
|
||||||
if not isinstance(value, ast.Name):
|
if not isinstance(value, ast.Name):
|
||||||
raise RuntimeError("FuncValue is should be Name:", ast.dump(func))
|
raise RuntimeError("FuncValue is should be Name:", ast.dump(func))
|
||||||
return value.id
|
return value.id
|
||||||
raise RuntimeError("FuncValue is should be Name or a Attribute:", ast.dump(func))
|
if isinstance(func, ast.Call):
|
||||||
|
return AssignParser._get_func_scope(func)
|
||||||
|
raise RuntimeError("FuncValue should be Name or a Attribute or a Call:", ast.dump(func))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_symbol_object(symbol_name, origin_net):
|
def _get_symbol_object(symbol_name, origin_net):
|
||||||
|
@ -206,6 +214,19 @@ class AssignParser(Parser):
|
||||||
return type(value), value
|
return type(value), value
|
||||||
return type(None), None
|
return type(None), None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_targets(all_targets: ScopedValue) -> [Union[ScopedValue, str]]:
|
||||||
|
"""Get targets from tuple or single value."""
|
||||||
|
targets: [Union[ScopedValue, str]] = []
|
||||||
|
if all_targets.type == ValueType.TupleValue:
|
||||||
|
for single_target in all_targets.value:
|
||||||
|
if not isinstance(single_target, ScopedValue) and not isinstance(single_target.value, str):
|
||||||
|
raise RuntimeError("Only support str target in tuple.")
|
||||||
|
targets.append(single_target)
|
||||||
|
else:
|
||||||
|
targets.append(all_targets)
|
||||||
|
return targets
|
||||||
|
|
||||||
def _update_field_in_init(self, func_scope, func_name, stree: SymbolTree, sub_tree: SymbolTree):
|
def _update_field_in_init(self, func_scope, func_name, stree: SymbolTree, sub_tree: SymbolTree):
|
||||||
"""
|
"""
|
||||||
When node is an invoking to sub-network, update value of ast.Assign of corresponding field in `__init__` method.
|
When node is an invoking to sub-network, update value of ast.Assign of corresponding field in `__init__` method.
|
||||||
|
@ -269,7 +290,7 @@ class AssignParser(Parser):
|
||||||
Raises:
|
Raises:
|
||||||
RuntimeError: If operator instance invoked by assign is undefined.
|
RuntimeError: If operator instance invoked by assign is undefined.
|
||||||
"""
|
"""
|
||||||
target = AssignParser._create_scopedvalue(father_ast_node.targets[0])
|
targets = AssignParser._get_targets(AssignParser._create_scopedvalue(father_ast_node.targets[0]))
|
||||||
func_name = AssignParser._get_func_name(ast_node)
|
func_name = AssignParser._get_func_name(ast_node)
|
||||||
if func_name is None or func_name == "":
|
if func_name is None or func_name == "":
|
||||||
raise RuntimeError("function name not exist")
|
raise RuntimeError("function name not exist")
|
||||||
|
@ -283,7 +304,7 @@ class AssignParser(Parser):
|
||||||
raise RuntimeError("Operator instance undefined: '", ast.unparse(ast_node.func), "' of '",
|
raise RuntimeError("Operator instance undefined: '", ast.unparse(ast_node.func), "' of '",
|
||||||
ast.unparse(ast_node), "'")
|
ast.unparse(ast_node), "'")
|
||||||
if isinstance(op, Primitive):
|
if isinstance(op, Primitive):
|
||||||
return Node.create_call_buildin_op(op, father_ast_node, [target], func, call_args, call_kwargs, func_name)
|
return Node.create_call_buildin_op(op, father_ast_node, targets, func, call_args, call_kwargs, func_name)
|
||||||
if isinstance(op, Cell):
|
if isinstance(op, Cell):
|
||||||
is_sub_tree = self._is_subtree_cell(op)
|
is_sub_tree = self._is_subtree_cell(op)
|
||||||
if is_sub_tree:
|
if is_sub_tree:
|
||||||
|
@ -292,9 +313,9 @@ class AssignParser(Parser):
|
||||||
self._update_field_in_init(func_scope, func_name, stree, new_stree)
|
self._update_field_in_init(func_scope, func_name, stree, new_stree)
|
||||||
replacer = AstReplacer(new_stree.get_class_ast())
|
replacer = AstReplacer(new_stree.get_class_ast())
|
||||||
replacer.replace_all(new_stree.get_ori_cls_name(), new_stree.get_opt_cls_name())
|
replacer.replace_all(new_stree.get_ori_cls_name(), new_stree.get_opt_cls_name())
|
||||||
return TreeNode(new_stree, father_ast_node, [target], func, call_args, call_kwargs, func_name,
|
return TreeNode(new_stree, father_ast_node, targets, func, call_args, call_kwargs, func_name,
|
||||||
new_stree.get_origin_network())
|
new_stree.get_origin_network())
|
||||||
return Node.create_call_buildin_op(op, father_ast_node, [target], func, call_args, call_kwargs, func_name)
|
return Node.create_call_buildin_op(op, father_ast_node, targets, func, call_args, call_kwargs, func_name)
|
||||||
raise RuntimeError("Only support Cell operator or Primitive operator, got ", type(op).__name__)
|
raise RuntimeError("Only support Cell operator or Primitive operator, got ", type(op).__name__)
|
||||||
|
|
||||||
def process(self, stree: SymbolTree, node: ast.Assign):
|
def process(self, stree: SymbolTree, node: ast.Assign):
|
||||||
|
@ -332,9 +353,9 @@ class AssignParser(Parser):
|
||||||
node_name = "constant_assign"
|
node_name = "constant_assign"
|
||||||
else:
|
else:
|
||||||
node_name = "attribute_assign"
|
node_name = "attribute_assign"
|
||||||
target = AssignParser._create_scopedvalue(node.targets[0])
|
targets = AssignParser._get_targets(AssignParser._create_scopedvalue(node.targets[0]))
|
||||||
call_args = [AssignParser._create_scopedvalue(value)]
|
call_args = [AssignParser._create_scopedvalue(value)]
|
||||||
node_ = Node.create_call_pass_through_method(node, [target], call_args, {}, node_name)
|
node_ = Node.create_call_pass_through_method(node, targets, call_args, {}, node_name)
|
||||||
stree.append_origin_field(node_)
|
stree.append_origin_field(node_)
|
||||||
elif isinstance(value, (ast.List, ast.Tuple, ast.Dict)):
|
elif isinstance(value, (ast.List, ast.Tuple, ast.Dict)):
|
||||||
# add these as callmethod node if necessary
|
# add these as callmethod node if necessary
|
||||||
|
|
|
@ -23,7 +23,7 @@ from .symbol_tree import SymbolTree
|
||||||
from .node import TreeNode
|
from .node import TreeNode
|
||||||
from .parser_register import ParserRegister
|
from .parser_register import ParserRegister
|
||||||
from .parser import Parser
|
from .parser import Parser
|
||||||
from .ast_transformers import FlattenRecursiveStmt
|
from .ast_transformers import FlattenRecursiveStmt, RemoveReturnOutOfIf
|
||||||
from .ast_helpers import AstModifier
|
from .ast_helpers import AstModifier
|
||||||
from .ast_helpers import AstFinder
|
from .ast_helpers import AstFinder
|
||||||
|
|
||||||
|
@ -55,7 +55,7 @@ class SymbolTreeBuilder:
|
||||||
Returns:
|
Returns:
|
||||||
An instance of ast been optimized.
|
An instance of ast been optimized.
|
||||||
"""
|
"""
|
||||||
transform_list = [FlattenRecursiveStmt()]
|
transform_list = [FlattenRecursiveStmt(), RemoveReturnOutOfIf()]
|
||||||
for transformer in transform_list:
|
for transformer in transform_list:
|
||||||
ast_root = transformer.transform(ast_root)
|
ast_root = transformer.transform(ast_root)
|
||||||
return ast_root
|
return ast_root
|
||||||
|
|
|
@ -0,0 +1,63 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
from mindspore.nn import Cell, Conv2d
|
||||||
|
from mindspore.rewrite import SymbolTree
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
|
|
||||||
|
class SubNet(Cell):
|
||||||
|
"""Sample cell which returns multiple features."""
|
||||||
|
def __init__(self):
|
||||||
|
"""Init."""
|
||||||
|
super().__init__()
|
||||||
|
self.conv = Conv2d(1, 10, 3)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""Construct."""
|
||||||
|
c1 = self.conv(x)
|
||||||
|
c2 = self.conv(c1)
|
||||||
|
c3 = self.conv(c2)
|
||||||
|
return c1, c2, c3
|
||||||
|
|
||||||
|
|
||||||
|
class NetMultiTargets(Cell):
|
||||||
|
"""Test cls for multiple targets."""
|
||||||
|
def __init__(self):
|
||||||
|
"""Init."""
|
||||||
|
super(NetMultiTargets, self).__init__()
|
||||||
|
self.conv1 = SubNet()
|
||||||
|
self.add = P.Add()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
"""Construct."""
|
||||||
|
c1, c2, c3 = self.conv1(x)
|
||||||
|
x = self.add(c1, c2)
|
||||||
|
x = self.add(x, c3)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_targets():
|
||||||
|
"""
|
||||||
|
Feature: Test multi-targets.
|
||||||
|
Description: Test multi-targets.
|
||||||
|
Expectation: Success.
|
||||||
|
"""
|
||||||
|
test_cls = NetMultiTargets()
|
||||||
|
stree = SymbolTree.create(test_cls)
|
||||||
|
node = stree.nodes()[2]
|
||||||
|
targets = node.get_targets()
|
||||||
|
assert targets[0].value == 'c1'
|
||||||
|
assert targets[1].value == 'c2'
|
||||||
|
assert targets[2].value == 'c3'
|
|
@ -0,0 +1,201 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import inspect
|
||||||
|
import ast
|
||||||
|
import astunparse
|
||||||
|
|
||||||
|
from mindspore.rewrite.ast_transformers import RemoveReturnOutOfIf
|
||||||
|
|
||||||
|
|
||||||
|
class TestIf:
|
||||||
|
"""Simple test."""
|
||||||
|
@staticmethod
|
||||||
|
def construct(x):
|
||||||
|
"""construct"""
|
||||||
|
if x > 2:
|
||||||
|
return x - 2
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TestIf2:
|
||||||
|
"""Test multiple if and test if in if."""
|
||||||
|
@staticmethod
|
||||||
|
def construct(x):
|
||||||
|
"""construct"""
|
||||||
|
if x > 2:
|
||||||
|
return x
|
||||||
|
x += 2
|
||||||
|
if x > 2:
|
||||||
|
if x > 2:
|
||||||
|
return x
|
||||||
|
x += 2
|
||||||
|
return x
|
||||||
|
x *= 2
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TestIf3:
|
||||||
|
"""Test else."""
|
||||||
|
@staticmethod
|
||||||
|
def construct(x):
|
||||||
|
"""construct"""
|
||||||
|
if x > 2:
|
||||||
|
x -= 2
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
x -= 2
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TestIf4:
|
||||||
|
"""Test elif."""
|
||||||
|
@staticmethod
|
||||||
|
def construct(x):
|
||||||
|
"""construct"""
|
||||||
|
if x > 2:
|
||||||
|
return x
|
||||||
|
x += 2
|
||||||
|
if x > 2:
|
||||||
|
x += 1
|
||||||
|
if x > 2:
|
||||||
|
x *= 2
|
||||||
|
elif x > 3:
|
||||||
|
x /= 3
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
x += 2
|
||||||
|
return x
|
||||||
|
x *= 2
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def test_simple_if():
|
||||||
|
"""
|
||||||
|
Feature: Test remove return from simple if.
|
||||||
|
Description: Test remove return from simple if.
|
||||||
|
Expectation: Success.
|
||||||
|
"""
|
||||||
|
ast_root: ast.Module = ast.parse(inspect.getsource(TestIf))
|
||||||
|
folder = RemoveReturnOutOfIf()
|
||||||
|
folder.transform(ast_root)
|
||||||
|
assert astunparse.unparse(ast_root) == """\n\nclass TestIf():
|
||||||
|
'Simple test.'\n
|
||||||
|
@staticmethod
|
||||||
|
def construct(x):
|
||||||
|
'construct'
|
||||||
|
if (x > 2):
|
||||||
|
output_0 = (x - 2)
|
||||||
|
else:
|
||||||
|
output_0 = x
|
||||||
|
return output_0
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_if():
|
||||||
|
"""
|
||||||
|
Feature: Test remove return from multiple if.
|
||||||
|
Description: Test remove return from multiple if.
|
||||||
|
Expectation: Success.
|
||||||
|
"""
|
||||||
|
ast_root: ast.Module = ast.parse(inspect.getsource(TestIf2))
|
||||||
|
folder = RemoveReturnOutOfIf()
|
||||||
|
folder.transform(ast_root)
|
||||||
|
assert astunparse.unparse(ast_root) == """\n\nclass TestIf2():
|
||||||
|
'Test multiple if and test if in if.'\n
|
||||||
|
@staticmethod
|
||||||
|
def construct(x):
|
||||||
|
'construct'
|
||||||
|
if (x > 2):
|
||||||
|
output_2 = x
|
||||||
|
else:
|
||||||
|
x += 2
|
||||||
|
if (x > 2):
|
||||||
|
if (x > 2):
|
||||||
|
output_0 = x
|
||||||
|
else:
|
||||||
|
x += 2
|
||||||
|
output_0 = x
|
||||||
|
output_1 = output_0
|
||||||
|
else:
|
||||||
|
x *= 2
|
||||||
|
output_1 = x
|
||||||
|
output_2 = output_1
|
||||||
|
return output_2
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def test_else():
|
||||||
|
"""
|
||||||
|
Feature: Test remove return in else of if node.
|
||||||
|
Description: Test remove return in else of if node.
|
||||||
|
Expectation: Success.
|
||||||
|
"""
|
||||||
|
ast_root: ast.Module = ast.parse(inspect.getsource(TestIf3))
|
||||||
|
folder = RemoveReturnOutOfIf()
|
||||||
|
folder.transform(ast_root)
|
||||||
|
assert astunparse.unparse(ast_root) == """\n\nclass TestIf3():
|
||||||
|
'Test else.'\n
|
||||||
|
@staticmethod
|
||||||
|
def construct(x):
|
||||||
|
'construct'
|
||||||
|
if (x > 2):
|
||||||
|
x -= 2
|
||||||
|
x -= 2
|
||||||
|
output_0 = x
|
||||||
|
else:
|
||||||
|
output_0 = x
|
||||||
|
return output_0
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def test_elif():
|
||||||
|
"""
|
||||||
|
Feature: Test remove return from elif.
|
||||||
|
Description: Test remove return from elif.
|
||||||
|
Expectation: Success.
|
||||||
|
"""
|
||||||
|
ast_root: ast.Module = ast.parse(inspect.getsource(TestIf4))
|
||||||
|
folder = RemoveReturnOutOfIf()
|
||||||
|
folder.transform(ast_root)
|
||||||
|
assert astunparse.unparse(ast_root) == """\n\nclass TestIf4():
|
||||||
|
'Test elif.'\n
|
||||||
|
@staticmethod
|
||||||
|
def construct(x):
|
||||||
|
'construct'
|
||||||
|
if (x > 2):
|
||||||
|
output_3 = x
|
||||||
|
else:
|
||||||
|
x += 2
|
||||||
|
if (x > 2):
|
||||||
|
x += 1
|
||||||
|
if (x > 2):
|
||||||
|
x *= 2
|
||||||
|
x += 2
|
||||||
|
output_1 = x
|
||||||
|
else:
|
||||||
|
if (x > 3):
|
||||||
|
x /= 3
|
||||||
|
x += 2
|
||||||
|
output_0 = x
|
||||||
|
else:
|
||||||
|
output_0 = x
|
||||||
|
output_1 = output_0
|
||||||
|
output_2 = output_1
|
||||||
|
else:
|
||||||
|
x *= 2
|
||||||
|
output_2 = x
|
||||||
|
output_3 = output_2
|
||||||
|
return output_3
|
||||||
|
"""
|
Loading…
Reference in New Issue