diff --git a/.jenkins/check/config/filter_pylint.txt b/.jenkins/check/config/filter_pylint.txt index fc8bc37bf63..dde8a938513 100644 --- a/.jenkins/check/config/filter_pylint.txt +++ b/.jenkins/check/config/filter_pylint.txt @@ -38,6 +38,11 @@ "mindspore/mindspore/python/mindspore/train/serialization.py" "protected-access" "mindspore/mindspore/python/mindspore/train/model.py" "protected-access" "mindspore/mindspore/python/mindspore/log.py" "protected-access" +"mindspore/mindspore/python/mindspore/rewrite/api/node.py" "protected-access" +"mindspore/mindspore/python/mindspore/rewrite/node.py" "protected-access" +"mindspore/mindspore/python/mindspore/rewrite/parser_register.py" "protected-access" +"mindspore/mindspore/python/mindspore/rewrite/api/pattern_engine.py" "protected-access" +"mindspore/mindspore/python/mindspore/rewrite/symbol_tree.py" "inconsistent-return-statements" "mindspore/model_zoo/official/cv" "missing-docstring" "mindspore/model_zoo/official/cv" "c-extension-no-member" "mindspore/model_zoo/official/nlp/bert_thor/src/bert_model.py" "redefined-outer-name" @@ -101,6 +106,10 @@ "mindspore/tests/ut/python/pynative_mode" "unused-variable" "mindspore/tests/ut/python/pynative_mode/test_stop_gradient.py" "redefined-outer-name" "mindspore/tests/ut/python/pynative_mode/test_stop_gradient.py" "super-init-not-called" +"mindspore/tests/ut/python/rewrite/test_flatten_recursive_stmt.py" "consider-using-ternary" +"mindspore/tests/ut/python/rewrite/test_node.py" "syntax-error" +"mindspore/tests/ut/python/rewrite/test_node.py" "protected-access" +"mindspore/tests/ut/python/rewrite/test_symbol_tree.py" "len-as-condition" "mindspore/tests/ut/python/test_log.py" "possibly-unused-variable" "mindspore/tests/ut/python/test_log.py" "protected-access" "mindspore/tests/ut/python/train/summary/test_summary_collector.py" "protected-access" diff --git a/cmake/package.cmake b/cmake/package.cmake index 5797e387833..58459da1637 100644 --- a/cmake/package.cmake +++ b/cmake/package.cmake @@ -306,6 +306,7 @@ install( ${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/communication ${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/profiler ${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/compression + ${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/rewrite ${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/run_check DESTINATION ${INSTALL_PY_DIR} COMPONENT mindspore diff --git a/mindspore/python/mindspore/__init__.py b/mindspore/python/mindspore/__init__.py index 1d2f213b4e6..53c958f3733 100755 --- a/mindspore/python/mindspore/__init__.py +++ b/mindspore/python/mindspore/__init__.py @@ -26,6 +26,7 @@ from .context import GRAPH_MODE, PYNATIVE_MODE, set_context, get_context, set_au get_auto_parallel_context, reset_auto_parallel_context, ParallelMode, set_ps_context, \ get_ps_context, reset_ps_context, set_fl_context, get_fl_context from .version import __version__ +from .rewrite import * __all__ = ["run_check"] @@ -34,3 +35,4 @@ __all__.extend(common.__all__) __all__.extend(train.__all__) __all__.extend(log.__all__) __all__.extend(context.__all__) +__all__.extend(rewrite.__all__) diff --git a/mindspore/python/mindspore/rewrite/__init__.py b/mindspore/python/mindspore/rewrite/__init__.py new file mode 100644 index 00000000000..5cd216be31f --- /dev/null +++ b/mindspore/python/mindspore/rewrite/__init__.py @@ -0,0 +1,32 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +MindSpore Rewrite module. +This is an experimental python package that is subject to change or deletion. +""" +from .parsers.module_parser import g_module_parser +from .parsers.class_def_parser import g_classdef_parser +from .parsers.function_def_parser import g_functiondef_parser +from .parsers.arguments_parser import g_arguments_parser +from .parsers.assign_parser import g_assign_parser +from .parsers.return_parser import g_return_parser +from .api.scoped_value import ScopedValue, ValueType +from .api.symbol_tree import SymbolTree +from .api.node import Node +from .api.node_type import NodeType +from .api.pattern_engine import PatternEngine, PatternNode, VarNode, Replacement + +__all__ = ["SymbolTree", "Node", "NodeType", "ScopedValue", "ValueType", "PatternEngine", "PatternNode", "VarNode", + "Replacement"] diff --git a/mindspore/python/mindspore/rewrite/api/__init__.py b/mindspore/python/mindspore/rewrite/api/__init__.py new file mode 100644 index 00000000000..d391580016d --- /dev/null +++ b/mindspore/python/mindspore/rewrite/api/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +MindSpore Rewrite api. +""" diff --git a/mindspore/python/mindspore/rewrite/api/node.py b/mindspore/python/mindspore/rewrite/api/node.py new file mode 100644 index 00000000000..ca247a14a70 --- /dev/null +++ b/mindspore/python/mindspore/rewrite/api/node.py @@ -0,0 +1,302 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Rewrite module api: Node.""" + +from typing import Union, Optional + +from mindspore.nn import Cell +from ..node import Node as NodeImpl +from ..symbol_tree import SymbolTree as SymbolTreeImpl +from .node_type import NodeType +from .scoped_value import ScopedValue + + +class Node: + """ + Node is a data structure represents a source code line in network. + + For the most part, Node represents an operator invoking in forward which could be an instance of `Cell`, an instance + of `Primitive` or a callable method. + + `NodeImpl` mentioned below is implementation of `Node` which is not an interface of Rewrite. Rewrite recommend + invoking specific create method of `Node` to instantiate an instance of Node such as `create_call_cell` rather than + invoking constructor of `Node` directly, so don't care about what is `NodeImpl` and use its instance just as a + handler. + + Args: + node (NodeImpl): A handler of `NodeImpl`. + """ + + def __init__(self, node: NodeImpl): + self._node = node + + def get_handler(self) -> NodeImpl: + """ + Get handler of node implementation. + + Returns: + An instance of `NodeImpl`. + """ + return self._node + + @staticmethod + def create_call_cell(cell: Cell, targets: [Union[ScopedValue, str]], args: [ScopedValue] = None, + kwargs: {str: ScopedValue}=None, name: str = "") -> 'Node': + """ + Create a node. Only support create from a `Cell` now. + + A node is corresponding to source code like: + + .. code-block:: + + `targets` = self.`name`(*`args`, **`kwargs`) + + Args: + cell (Cell): Cell-operator of this forward-layer. + targets (list[ScopedValue]): Indicate output names. Used as targets of an assign statement in source code. + Rewrite will check and ensure the uniqueness of each target while node being inserted. + args (list[ScopedValue]): Indicate input names. Used as args of a call expression of an assign statement in + source code. Default is None indicate the `cell` has no args inputs. Rewrite will check and ensure the + uniqueness of each arg while node being inserted. + kwargs (dict{str: ScopedValue}): Indicate keyword input names. Used as kwargs of a call expression of an + assign statement in source code. Default is None indicate the `cell` has no kwargs inputs. Rewrite will + check and ensure the uniqueness of each kwarg while node being inserted. + name (str): Indicate the name of node. Used as field name in source code. Default is None. Rewrite will + generate name from `targets` when name is None. Rewrite will check and ensure the uniqueness of `name` + while node being inserted. + + Returns: + An instance of `Node`. + + Raises: + RuntimeError: If `cell` is not a `Cell`. + RuntimeError: If `targets` is None. + RuntimeError: If target in `targets` is not a `NamingValue`-`ScopedValue`. + RuntimeError: If arg in `args` is not a `NamingValue`-`ScopedValue` or a `CustomObjValue`-`ScopedValue`. + RuntimeError: If value of kwarg in `kwargs` is not a `NamingValue`-`ScopedValue` or a + `CustomObjValue`-`ScopedValue`. + """ + return Node(NodeImpl.create_call_cell(cell, None, targets, ScopedValue.create_naming_value(name, "self"), args, + kwargs, name)) + + def get_prev(self) -> 'Node': + """ + Get previous node of current node in source code order. + + Returns: + An instance of `Node` as previous node. + """ + return Node(self._node.get_prev()) + + def get_next(self) -> 'Node': + """ + Get next node of current node in source code order. + + Returns: + An instance of `Node` as next node. + """ + return Node(self._node.get_next()) + + def get_inputs(self) -> ['Node']: + """ + Get input nodes of current node in topological order. + + Returns: + A list of instances of `Node` as input nodes. + """ + return [Node(node_impl) for node_impl in self._node.get_inputs()] + + def get_users(self) -> ['Node']: + """ + Get output nodes of current node in topological order. + + Returns: + A list of nodes represents users. + """ + belong_symbol_tree: SymbolTreeImpl = self._node.get_belong_symbol_tree() + if belong_symbol_tree is None: + return [] + unique_results = [] + for node_user in belong_symbol_tree.get_node_users(self._node): + node = node_user[0] + if node not in unique_results: + unique_results.append(node) + return [Node(node_impl) for node_impl in unique_results] + + def set_arg(self, index: int, arg: Union[ScopedValue, str]): + """ + Set argument of current node. + + Args: + index (int): Indicate which input being modified. + arg (Union[ScopedValue, str]): New argument to been set. + + Raises: + RuntimeError: If `index` is out of range. + RuntimeError: If `arg` a `NamingValue`-`ScopedValue` or a `CustomObjValue`-`ScopedValue` when `arg` is an + `ScopedValue`. + """ + belong_symbol_tree: SymbolTreeImpl = self._node.get_belong_symbol_tree() + if belong_symbol_tree is None: + self._node.set_arg(arg, index) + else: + belong_symbol_tree.set_node_arg(self._node, index, arg) + + def set_arg_by_node(self, arg_idx: int, src_node: 'Node', out_idx: Optional[int] = None): + """ + Set argument of current node by another `Node`. + + Args: + arg_idx (int): Indicate which input being modified. + src_node (Node): A `Node` as new input. Can be a node or name of node. + out_idx (int, optional): Indicate which output of `src_node` as new input of current node. Default is None + which means use first output of `src_node` as new input. + + Raises: + RuntimeError: If `src_node` is not belong to current `SymbolTree`. + RuntimeError: If current node and `src_node` is not belong to same `SymbolTree`. + RuntimeError: If `arg_idx` is out of range. + RuntimeError: If `out_idx` is out of range. + RuntimeError: If `src_node` has multi-outputs while `out_idx` is None or `out_idx` is not offered. + """ + belong_symbol_tree: SymbolTreeImpl = self._node.get_belong_symbol_tree() + if belong_symbol_tree is None: + self._node.set_arg_by_node(arg_idx, src_node._node, out_idx) + else: + belong_symbol_tree.set_node_arg_by_node(self._node, arg_idx, src_node._node, out_idx) + + def get_targets(self) -> [ScopedValue]: + """ + Get targets of current node. + + - When node_type of current node is `CallCell`, `CallPrimitive`, `CallMethod` or `Tree`, `targets` are strings + represents invoke result of the cell-op or primitive-op or function-call which are corresponding to targets of + ast.Assign. + - When node_type of current node is Input, `targets` should have only one element which is a string represents + parameter of function. + - When node_type of current node is `Python` or `Output`, `targets` are don't-care. + + Returns: + A list of instances of ScopedValue as targets of node. + """ + return self._node.get_targets() + + def get_name(self) -> str: + """ + Get the name of current node. + + When node has been inserted into `SymbolTree`, the name of node should be unique in `SymbolTree`. + + Returns: + A string as name of node. + """ + return self._node.get_name() + + def get_node_type(self) -> NodeType: + """ + Get the node_type of current node. + + Returns: + A NodeType as node_type of node. + """ + return self._node.get_node_type() + + def get_instance_type(self) -> type: + """ + Get the instance_type of current node. + + - When node_type of current node is `CallCell`, instance_type is type of cell-op. + - When node_type of current node is `CallPrimitive`, instance_type is type of primitive-op. + - When node_type of current node is `Tree`, instance_type is type of network-cell. + - When node_type of current node is `Python`, `Input`, `Output` or `CallMethod`, instance_type should be + NoneType. + + Returns: + A type object represents corresponding instance type of current node. + """ + return self._node.get_instance_type() + + def get_instance(self): + """ + Get the instance of current node. + + - When node_type of current node is `CallCell`, instance is an instance of Cell. + - When node_type of current node is `CallPrimitive`, instance is an instance of primitive. + - When node_type of current node is `Tree`, instance is an instance of network-cell. + - When node_type of current node is `Python`, `Input`, `Output` or `CallMethod`, instance should be None. + + Returns: + A object represents corresponding instance of current node. + """ + return self._node.get_instance() + + def get_args(self) -> [ScopedValue]: + """ + Get the arguments of current node. + + - When node_type of current node is `CallCell`, `CallPrimitive` or `Tree`, arguments are corresponding to args + of ast.Call which represents arguments to invoke forward method of cell-op or primitive-op. + - When node_type of current node is `Input`, arguments represents default-value of argument of function. + - When node_type of current node is `Output`, arguments represents return values. + - When node_type of current node is `Python`, arguments are don't-care. + + Returns: + A list of instances of `ScopedValue`. + """ + return self._node.get_args() + + def get_kwargs(self) -> {str: ScopedValue}: + """ + Get the keyword arguments of current node. + + - When node_type of current node is `CallCell`, `CallPrimitive` or `Tree`, keyword arguments are corresponding + to kwargs of ast.Call which represents arguments to invoke forward method of cell-op or primitive-op. + - When node_type of current node is `Python`, `Input` or `Output`, keyword arguments are don't-care. + + Returns: + A dict of str to instance of `ScopedValue`. + """ + return self._node.get_kwargs() + + def set_attribute(self, key: str, value): + """ + Set attribute of current node. + + Args: + key (str): Key of attribute. + value (object): Value of attribute. + """ + self._node.set_attribute(key, value) + + def get_attributes(self) -> {str: object}: + """ + Get all attributes of current node. + + Returns: + A dict of str to instance of object as attributes. + """ + return self._node.get_attributes() + + def get_attribute(self, key: str): + """ + Get attribute of current node by key. + + Returns: + A object as attribute. + """ + return self._node.get_attribute(key) + + def __eq__(self, other: 'Node'): + return self._node == other._node diff --git a/mindspore/python/mindspore/rewrite/api/node_type.py b/mindspore/python/mindspore/rewrite/api/node_type.py new file mode 100644 index 00000000000..bb6bbc563d3 --- /dev/null +++ b/mindspore/python/mindspore/rewrite/api/node_type.py @@ -0,0 +1,43 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Rewrite module api: NodeType.""" +from enum import Enum + + +class NodeType(Enum): + """ + `NodeType` represents type of `Node`. + + - Unknown: Not inited NodeType. + - CallCell: `CallCell` node represents invoking cell-op in forward method. + - CallPrimitive: `CallPrimitive` node represents invoking primitive-op in forward method. + - CallMethod: `CallMethod` node represents invoking of method in forward method which can not be mapped to + cell-op or primitive-op in MindSpore. + - Python: `Python` node holds unsupported-ast-node or unnecessary-to-parse-ast-node. + - Input: `Input` node represents input of `SymbolTree` corresponding to arguments of forward method. + - Output: `Output` node represents output of SymbolTree corresponding to return statement of forward method. + - Tree: `Tree` node represents sub-network invoking in forward method. + + """ + Unknown = 0 + # Compute node type + CallCell = 1 + CallPrimitive = 2 + CallMethod = 3 + # Other node type + Python = 4 + Input = 5 + Output = 6 + Tree = 7 diff --git a/mindspore/python/mindspore/rewrite/api/pattern_engine.py b/mindspore/python/mindspore/rewrite/api/pattern_engine.py new file mode 100644 index 00000000000..b2875fef5d8 --- /dev/null +++ b/mindspore/python/mindspore/rewrite/api/pattern_engine.py @@ -0,0 +1,413 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""PatternEngine for modifying SymbolTree by pattern.""" +from collections import OrderedDict +from typing import Tuple, Union, List, Type +import abc + +from mindspore.nn import Cell +from mindspore import log as logger +from .node_type import NodeType +from .node import Node +from .symbol_tree import SymbolTree + + +class PatternNode: + """ + `PatternNode` is defined as a node while defining pattern. + + Args: + pattern_node_name (str): Name of current node. + match_type (Type): A type represents what type would be matched of current node. + inputs (list[PatternNode]): Input nodes of current node. + """ + + def __init__(self, pattern_node_name: str, match_type: Type = Type[None], inputs: ['PatternNode'] = None): + self._name = pattern_node_name + self._type = match_type + if inputs is None: + self._inputs = [] + else: + self._inputs = inputs + + @staticmethod + def from_node(node: Node) -> 'PatternNode': + """ + Create a `PatternNode` from `node`. + + Args: + node (Node): Input rewrite node. + + Returns: + A `PatternNode` created from `node`. + """ + + pattern_node: PatternNode = PatternNode(node.get_targets()[0]) + if node.get_node_type() is NodeType.CallCell: + pattern_node._type = node.get_instance_type() + return pattern_node + + @staticmethod + def create_pattern_from_node(node: Node) -> 'PatternNode': + """ + Create a Pattern from `node` with its inputs. + + Args: + node (Node): Input rewrite node. + + Returns: + A `PatternNode` as root of pattern created from rewrite node. + """ + + pattern_node: PatternNode = PatternNode.from_node(node) + inputs = [] + for node_input in node.get_inputs(): + inputs.append(PatternNode.create_pattern_from_node(node_input)) + pattern_node._inputs = inputs + return pattern_node + + @staticmethod + def create_pattern_from_list(type_list: []) -> 'PatternNode': + """ + Create a Pattern from a cell type list. + + Args: + type_list (list[type]): Input cell type list. + + Returns: + A `PatternNode` as root of pattern created from cell type list. + """ + + last_node = None + for i, cell_type in enumerate(type_list): + cur_node: PatternNode = PatternNode(str(i) + "-" + str(cell_type), cell_type, []) + if last_node is not None: + cur_node._inputs = [last_node] + else: + cur_node._inputs = [] + last_node = cur_node + return last_node + + def add_input(self, node): + """ + Add an input for current `PatternNode`. + + Args: + node (PatternNode): Cell type as an input. + """ + + self._inputs.append(node) + + def set_inputs(self, inputs): + """ + Set inputs for current `PatternNode`. + + Args: + inputs (list[PatternNode]) : Inputs to be set as inputs of current `PatternNode`. + """ + + self._inputs = inputs + + def match(self, node: Node) -> bool: + """ + Check if current `PatternNode` can match with `node`. + + Args: + node (Node) : A rewrite node to be match. + """ + + return self._type == node.get_instance_type() + + def get_inputs(self): + """ + Getter of inputs. + """ + + return self._inputs + + def name(self) -> str: + """ + Getter of name. + """ + return self._name + + def type(self): + """ + Getter of type. + """ + return self._type + + +class VarNode(PatternNode): + """ + VarNode is a subclass of `PatternNode` whose `match` method is always return True. + """ + + def __init__(self): + super(VarNode, self).__init__("placeholder", Cell, []) + + def match(self, node: Node) -> bool: + return node is not None and node.get_handler() is not None + + +class Replacement(abc.ABC): + """ + Interface of replacement function. + """ + + @abc.abstractmethod + def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]: + """ + Interface define for creating replacement nodes from matched result. + + Note: + Return value will be delivered into replace api of `SymbolTree` as argument, return value should follow + restraint of parameter `new_nodes` of `replace` api if `SymbolTree`. See detail in docstring of `replace` + api of `SymbolTree`. + + Args: + pattern (PatternNode): A `PatternNode` represents root node of current pattern. + is_chain_pattern (bool): A bool indicated if pattern is a chain pattern or a tree pattern. + matched (OrderedDict): An OrderedDict map from pattern_node name to node represents matched result. + + Returns: + A list of instance of `Node` as replacement nodes. + """ + + raise NotImplementedError + + def __call__(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]: + return self.build(pattern, is_chain_pattern, matched) + + +class PatternEngine: + """ + `PatternEngine` is defined how to transform a `SymbolTree` by `PattenNode`. + + Args: + pattern (Union[PatternNode, List]): An instance of `PatternNode` or a cell-type-list to construct `PatternNode` + as root of a pattern. + replacement (callable): A callable define how to generate new_node. + """ + + def __init__(self, pattern: Union[PatternNode, List], replacement: Replacement = None): + if isinstance(pattern, PatternNode): + self._is_chain = False + self._replacement: Replacement = replacement + self._pattern: PatternNode = pattern + elif isinstance(pattern, list): + self._is_chain = True + self._replacement: Replacement = replacement + self._pattern: PatternNode = PatternNode.create_pattern_from_list(pattern) + else: + raise RuntimeError("Unsupported pattern define") + + def pattern(self) -> PatternNode: + """ + Getter of pattern. + """ + + return self._pattern + + @staticmethod + def _multi_to_multi_replace(stree: SymbolTree, old_root: Node, matched_dict: OrderedDict, + new_nodes: [Node]) -> Node: + """ + Replace multi-nodes in `stree` by another list of nodes. + + Note: + Call replace api of `SymbolTree`, so parameter `new_nodes` has same restraint with parameter `new_nodes` of + `replace` api if `SymbolTree`. See detail in docstring of `replace` api of `SymbolTree`. + + Args: + stree (SymbolTree): A `SymbolTree` which replacement will apply on. + old_root (Node): A `Node` represents root of original nodes. + matched_dict (OrderedDict): An instance of OrderedDict as match result, where key is the pattern name, value + is the matched node. + new_nodes (list[Node]): A list of instance of Node as replacement. + """ + to_erase_list = matched_dict.values() + # keep all old nodes' inputs + inputs_dict = {} + for node in to_erase_list: + inputs_dict[node.get_name()] = (node.get_inputs()) + # call replace of SymbolTree + new_root = stree.replace(old_root, new_nodes) + # replace only support one-to-one replace or one-to-multi replace, we need to erase nodes except + # cur_node manually + queue: [Node] = [old_root] + while queue: + cur_node: Node = queue.pop(0) + if cur_node in to_erase_list: + if cur_node.get_users(): + # if cur_node is depended on by other node, skip now. + # cur_node will be push into queue and be erased later + continue + if stree.get_node(cur_node.get_name()) is not None: + # cur_node is not erased before + stree.erase_node(cur_node) + queue.extend(inputs_dict.get(cur_node.get_name())) + return new_root + + def apply(self, stree: SymbolTree) -> bool: + """ + Apply current pattern to a `SymbolTree`. + + Note: + Sub-tree node will be supported in the near feature. + + Args: + stree (SymbolTree): A `SymbolTree` to be transformed. + + Returns: + A bool represents if `stree` been changed. + + Raises: + RuntimeError: If `SymbolTree` has no return node. + """ + + root: Node = stree.get_return_node() + if root is None: + raise RuntimeError("SymbolTree should be inited and has return node") + changed = False + # IR match + queue: [Node] = [root] + # Why need visited: we don't need or should not to visit same node multi-times because pre-visited node may + # already been erased from SymbolTree. + # When will we visit same node multi-times: + # a + # / \ + # / \ + # b c + # | | + # | d + # \ / + # \ / + # e + # 1. Visit e, e does not match pattern, add b, d to queue. + # 2. Visit b, b does not match pattern, add a to queue. + # 3. Visit d, d does not match pattern, add c to queue. + # 4. Visit a, a matches pattern and erased from SymbolTree, add xx to queue. + # 5. Visit c, d does not match pattern, add a to queue. + # At step 5, a is visited at second time but a is erased from SymbolTree at step 4. + visited: [Node] = [] + while queue: + cur_node: Node = queue.pop(0) + if cur_node is None: # Because inputs of node is allowed to be None in replacement. + continue + if cur_node in visited: + continue + visited.append(cur_node) + matched, matched_dict = self._match(self._pattern, cur_node) + # not matched + if not matched or not PatternEngine._check_match(self._pattern, matched_dict): + if cur_node is not None: + queue.extend(cur_node.get_inputs()) + continue + # matched + new_nodes: [Node] = [] + if self._replacement is not None: + new_nodes: [Node] = self._replacement(self._pattern, self._is_chain, matched_dict) + if not new_nodes: # if replacement is empty, do nothing + queue.extend(cur_node.get_inputs()) + else: # replace cur_node with new_nodes + changed = True + root = PatternEngine._multi_to_multi_replace(stree, cur_node, matched_dict, new_nodes) + queue.append(root) + return changed + + @staticmethod + def _merge_ordered_dict(dict1: OrderedDict, dict2: OrderedDict) -> OrderedDict: + """ + A static util method to merge two OrderedDict. + + Args: + dict1 (OrderedDict): First dict to be merged. + dict2 (OrderedDict): Second dict to be merged. + + Returns: + Merged OrderedDict. + """ + + merged = dict1.copy() + merged.update(dict2) + return merged + + def _match(self, pattern: PatternNode, node: Node) -> Tuple[bool, OrderedDict]: + """ + Match `pattern` with a `node` with all inputs of the `pattern`. + + Args: + pattern (PatternNode): Pattern to be match. + node (Node): Node to be match. + + Returns: + A bool value to indicate if matched. + An instance of OrderedDict as match result, where key is the pattern name, value is the matched node. + """ + + # Don't iterate into subgraph node, pattern should not be matched across sub-tree + if node.get_node_type() != NodeType.CallCell and node.get_node_type() != NodeType.Input: + logger.debug("Pattern match failed: node(%s) is not a cell", str(node)) + return False, OrderedDict() + if not pattern.match(node): + logger.debug("Pattern match failed: node(%s)'s type is %s while pattern type is %s", str(node), + node.get_instance_type(), pattern.type()) + return False, OrderedDict() + if isinstance(pattern, VarNode): + return True, OrderedDict() + pattern_inputs = pattern.get_inputs() + cur_inputs = node.get_inputs() + input_num = len(pattern_inputs) + if input_num == 0: + return True, OrderedDict({pattern.name(): node}) + if input_num != len(cur_inputs): + logger.debug("Pattern match failed: node(%s)'s has %d inputs while pattern has %d inputs", str(node), + len(node.get_inputs()), input_num) + return False, OrderedDict() + result = OrderedDict() + for i in range(0, input_num): + is_matched, tmp_result = self._match(pattern_inputs[i], cur_inputs[i]) + if not is_matched: + return False, OrderedDict() + result = PatternEngine._merge_ordered_dict(result, tmp_result) + result[pattern.name()] = node + return True, result + + @staticmethod + def _check_match(pattern: PatternNode, match_dict: OrderedDict) -> bool: + """ + Check if matched result is a leak result. + + A leak result means that the result is matched the `pattern`, but some nodes in result which is not + corresponding to root of pattern have outputs used by nodes outside of result. + + Args: + pattern (PatternNode): A `PatternNode` represents pattern to be match. + match_dict (OrderedDict): A OrderedDict represents matched result. + + Returns: + A bool value to indicate if matched result leaked. + """ + matched_nodes = match_dict.values() + for key in match_dict: + if key == pattern.name(): + continue + node: Node = match_dict[key] + for output in node.get_users(): + if output not in matched_nodes: + logger.debug("Check match failed, pattern leaked") + return False + return True diff --git a/mindspore/python/mindspore/rewrite/api/scoped_value.py b/mindspore/python/mindspore/rewrite/api/scoped_value.py new file mode 100644 index 00000000000..e191d2b9e2e --- /dev/null +++ b/mindspore/python/mindspore/rewrite/api/scoped_value.py @@ -0,0 +1,152 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Rewrite module api: ValueType and ScopedValue.""" +from enum import Enum +from typing import Optional + + +class ValueType(Enum): + """ + ValueType represents type of `ScopedValue`. + + - A `NamingValue` represents a reference to another variable. + - A `CustomObjValue` represents an instance of custom class or an object whose type is out of range of base-type + and container-type of ValueType. + """ + + # base type + StringValue = 0 + IntValue = 1 + FloatValue = 2 + # container type + TupleValue = 20 + ListValue = 21 + DictValue = 22 + # other type + NamingValue = 40 + CustomObjValue = 41 + + +class ScopedValue: + """ + `ScopedValue` represents a value with its full-scope. + + `ScopedValue` is used to express: a left-value such as target of an assign statement, or a callable object such as + func of a call statement, or a right-value such as args and kwargs of an assign statement. + + Args: + arg_type (ValueType): A `ValueType` represents type of current value. + scope (str): A string represents scope of current value. Take "self.var1" as an example, `scope` of this + var1 is "self". + value: A handler represents value of current value. The type of value is corresponding to `arg_type`. + """ + + def __init__(self, arg_type: ValueType, scope: str = "", value=None): + self.type = arg_type + self.scope = scope + self.value = value + + @classmethod + def create_variable_value(cls, value) -> Optional['ScopedValue']: + """ + Create `ScopedValue` from a variable. + + `ScopedValue`'s type is determined by type of value. `ScopedValue`'s scope is empty. + + Args: + value: The value to be converted to `ScopedValue`. + + Returns: + An instance of `ScopedValue`. + """ + if isinstance(value, int): + return cls(ValueType.IntValue, "", value) + if isinstance(value, float): + return cls(ValueType.FloatValue, "", value) + if isinstance(value, str): + return cls(ValueType.StringValue, "", value) + if isinstance(value, tuple): + return cls(ValueType.TupleValue, "", + tuple(cls.create_variable_value(single_value) for single_value in value)) + if isinstance(value, list): + return cls(ValueType.ListValue, "", list(cls.create_variable_value(single_value) for single_value in value)) + if isinstance(value, dict): + for key, _ in value.items(): + if not isinstance(key, str): + raise TypeError("key should be str, got: ", type(key)) + return cls(ValueType.DictValue, "", + dict((key, cls.create_variable_value(single_value)) for key, single_value in value.items())) + return cls(ValueType.CustomObjValue, "", value) + + @classmethod + def create_naming_value(cls, name: str, scope: str = "") -> 'ScopedValue': + """ + Create a naming `ScopedValue`. A `NamingValue` represents a reference to another variable. + + Args: + name: (str): A string represents the identifier of another variable. + scope: (str): A string represents the scope of another variable. + + Returns: + An instance of `ScopedValue`. + """ + return cls(ValueType.NamingValue, scope, name) + + @staticmethod + def create_name_values(names: [str], scopes: [str] = None) -> ['ScopedValue']: + """ + Create a list of naming `ScopedValue`. + + Args: + names: (list[str]): A list of string represents names of referenced variables. + scopes: (list[str]): A list of string represents scopes of referenced variables. + + Returns: + An list of instance of `ScopedValue`. + + Raise: + RuntimeError: If the length of names is not equal to the length of scopes when scopes are not None. + """ + if scopes is not None: + if len(names) != len(scopes): + raise RuntimeError("Length of names should be equal to length of scopes") + result = [] + for index, name in enumerate(names): + if scopes is not None: + scope = scopes[index] + else: + scope = "" + result.append(ScopedValue.create_naming_value(name, scope)) + return result + + def __str__(self): + if self.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue): + return str(self.value) + if self.type == ValueType.NamingValue: + return f"{self.scope}.{self.value}" if self.scope else str(self.value) + if self.type == ValueType.CustomObjValue: + return f"CustomObj: {str(self.value)}" + return f"Illegal ValueType: {str(self.type)}" + + def __eq__(self, other): + if id(self) == id(other): + return True + return self.type == other.type and self.scope == other.scope and self.value == other.value + + def __repr__(self): + return str(self) + + def __hash__(self): + return hash((self.type, self.scope, self.value)) diff --git a/mindspore/python/mindspore/rewrite/api/symbol_tree.py b/mindspore/python/mindspore/rewrite/api/symbol_tree.py new file mode 100644 index 00000000000..fc2d3e2ca57 --- /dev/null +++ b/mindspore/python/mindspore/rewrite/api/symbol_tree.py @@ -0,0 +1,227 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Rewrite module api: SymbolTree.""" +from typing import Optional + +from mindspore.nn import Cell +from .node import Node +from ..symbol_tree_builder import SymbolTreeBuilder +from ..symbol_tree import SymbolTree as SymbolTreeImpl + + +class SymbolTree: + """ + A `SymbolTree` usually corresponding to forward method of a network. + + Args: + network (Cell): Network to be rewritten. Only support `Cell`-type-network now. + + Raises: + RuntimeError: If `network` is not a Cell. + RuntimeError: If there is any unsupported ast node type while parsing or optimizing. + """ + + def __init__(self, network: Cell): + self._symbol_tree: SymbolTreeImpl = SymbolTreeBuilder(network).build() + + def get_handler(self) -> SymbolTreeImpl: + """ + Get handler of `SymbolTree` implementation. + + Returns: + An instance of `SymbolTree`. + """ + return self._symbol_tree + + def nodes(self) -> {}: + """ + Get all nodes of corresponding network. + + Returns: + A dict mapping from name of node to node. + """ + return [Node(node_impl) for node_impl in self._symbol_tree.nodes(unfold_subtree=True)] + + def get_node(self, node_name: str) -> Optional[Node]: + """ + Get node by `node_name`. + + Args: + node_name (str): A string represents name of node. + + Returns: + An instance of node if find else None. + """ + node_impl = self._symbol_tree.get_node(node_name) + if node_impl is None: + return None + return Node(node_impl) + + def get_return_node(self) -> Node: + """ + Get return node of current `SymbolTree`. + + A `SymbolTree` can and should have one return node corresponding to return statement of forward method of + network. + + Returns: + An instance of node represents return node. + """ + ret = self._symbol_tree.get_return_node() + if ret is None: + raise RuntimeError("SymbolTree is not well inited, can not find return node.") + return Node(ret) + + def before(self, node: Node): + """ + Get insert position before input `node`. + + `Position` is used to indicate where to insert node, it indicates position in source code rather than position + in topological order. We don't need to care about what `Position` is, just treat it as a handler and use it as + an arguments of `insert` api of `SymbolTree`. + + Args: + node (Node): Indicate the position before which node. Can be a node or name of node. + + Returns: + A `Position` to indicate where to insert node. + + Raises: + RuntimeError: if `node` is not a Node or a string. + """ + return self._symbol_tree.before(node.get_handler()) + + def after(self, node: Node): + """ + Get insert position after input `node`. + + `Position` is used to indicate where to insert node, it indicates position in source code rather than position + in topological order. We don't need to care about what `Position` is, just treat it as a handler and use it as + an arguments of `insert` api of `SymbolTree`. + + Args: + node (Node): Indicate the position after which node. Can be a node or name of node. + + Returns: + A `Position` to indicate where to insert node. + + Raises: + RuntimeError: If `node` is not a Node. + """ + return self._symbol_tree.after(node.get_handler()) + + def insert(self, position, node: Node) -> Node: + """ + Insert a `node` into `SymbolTree` at `position`. + + `position` is obtained from `before` api or `after` api of `SymbolTree`. + + Args: + position (Position): Indicate where to insert `node`. + node (Node): An instance of Node to be inserted. + + Returns: + An instance of Node being inserted. `node` could be changed while calling this method for uniqueness and + custom-object in args or kwargs. + + Raises: + RuntimeError: If `position` is not belong to current `SymbolTree`. + """ + return Node(self._symbol_tree.insert_node(position, node.get_handler())) + + def erase_node(self, node: Node) -> Optional[Node]: + """ + Erase a `node` from rewrite. Can only erase a node not being depended on. + + Args: + node (Node): A `Node` to be erased. Can be a node or name of node. + + Returns: + An instance of `Node` being erased if node is in `SymbolTree` else None. + + Raises: + RuntimeError: If `node` is not a `Node`. + """ + return Node(self._symbol_tree.erase_node(node.get_handler())) + + def replace(self, old_node: Node, new_nodes: [Node]) -> Node: + """ + Replace `old_node` with a node_tree. + + Note: + + 1. Replace support one-to-one replacement or one-to-multi replacement. If you need multi-to-multi + replacement, please refer to `PatternEngine`. + 2. When applying one-to-multi replacement, Rewrite will insert all `new_nodes` into symbol_tree. + 3. Caller should maintain arguments and targets of nodes intra sub-tree for specifying topological relation + intra sub-tree. + 4. Caller should maintain arguments of input nodes of sub-tree and for specifying topological relation of + inputs of sub-tree. + 5. Rewrite will maintain arguments of prepend node of sub-tree for specifying topological relation of + outputs of sub-tree. + 6. Rewrite will maintain all inputs of nodes after replace `new_nodes` into `SymbolTree`. + + Args: + old_node (Node): Node to be replaced. + new_nodes (list[Node]): Nodes of the node_tree to replace in. + + Returns: + An instance of Node represents root of node_tree been replaced in. + + Raises: + RuntimeError: Old node is isolated. + """ + nodes_impl = [node.get_handler() for node in new_nodes] + return Node(self._symbol_tree.replace(old_node.get_handler(), nodes_impl)) + + def set_output(self, index: int, return_value: str) -> Node: + """ + Set return value of network. + + Args: + index (int): Indicate which output being modified. + return_value (str): New return value to been set. + + Returns: + Return node of current rewrite. + + Raises: + RuntimeError: If `index` is out of range. + """ + return Node(self._symbol_tree.set_output(return_value, index)) + + def dump(self): + """ + Dump graph to console. + """ + self._symbol_tree.dump() + + def get_code(self) -> str: + """ + Get source code of modified network. + + Returns: + A str represents source code of modified network. + """ + return self._symbol_tree.get_code() + + def get_network(self) -> Cell: + """ + Get modified network. + + Returns: + A network object. + """ + return self._symbol_tree.get_network() diff --git a/mindspore/python/mindspore/rewrite/ast_modifier.py b/mindspore/python/mindspore/rewrite/ast_modifier.py new file mode 100644 index 00000000000..d814c126884 --- /dev/null +++ b/mindspore/python/mindspore/rewrite/ast_modifier.py @@ -0,0 +1,292 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Ast utils for create or update ast node.""" +from typing import Optional +import ast + +from .api.scoped_value import ScopedValue, ValueType + + +class AstModifier(ast.NodeTransformer): + """Ast utils for create or update ast node.""" + + @staticmethod + def erase_ast_from_function(ast_func: ast.FunctionDef, to_erase: ast.AST) -> bool: + """ + Erase ast node from ast.FunctionDef. + + Args: + ast_func (ast.FunctionDef): From which to search to_erase-node and erase. + to_erase (ast.AST): Node to be erased. + + Returns: + A bool if to_erase-node been found and been erased. + """ + for body in ast_func.body: + if id(body) == id(to_erase): + ast_func.body.remove(body) + return True + return False + + @staticmethod + def insert_assign_to_function(ast_func: ast.FunctionDef, targets: [ScopedValue], expr: ScopedValue, + args: [ScopedValue] = None, kwargs: {str, ScopedValue}=None, + index_ast: Optional[ast.AST] = None, insert_before=True) -> ast.AST: + """ + Insert an ast.Assign into an ast.FunctionDef. + + Args: + ast_func (ast.FunctionDef): Where new ast.Assign to be inserted into. + targets ([ScopedValue]): Targets of ast.Assign. + expr (ScopedValue): Func of ast.Call which is value of new ast.Assign. + args ([ScopedValue]): Args of ast.Call which is value of new ast.Assign. + kwargs ({str, ScopedValue}): Kwargs of ast.Call which is value of new ast.Assign. + index_ast (Optional[ast.AST]): An ast_node indicates a position in 'ast_func' where new ast.Assign node to + be inserted into. Default is None which means append new ast.Assign into + 'ast_func'. + insert_before (bool): A bool indicates at before or at after of 'index_ast' where new ast.Assign node to be + inserted into. Only valid when 'index_ast' is not None. Default is True which means + inserting new ast.Assign before 'index_ast'. + + Returns: + An instance of ast.Assign which has been inserted into 'ast_func'. + + Raises: + RuntimeError: If 'index_ast' is not contained in 'ast_func'. + """ + assign = AstModifier.create_call_assign(targets, expr, args, kwargs) + return AstModifier.insert_assign_ast_to_function(ast_func, assign, index_ast, insert_before) + + @staticmethod + def insert_assign_ast_to_function(ast_func: ast.FunctionDef, ast_assign: ast.Assign, + index_ast: Optional[ast.AST] = None, insert_before=True) -> ast.AST: + """ + Insert an ast.Assign into an ast.FunctionDef. + + Args: + ast_func (ast.FunctionDef): Where new ast.Assign to be inserted into. + ast_assign (ast.Assign): An instance of ast.Assign to be inserted in. + index_ast (Optional[ast.AST]): An ast_node indicates a position in 'ast_func' where new ast.Assign node to + be inserted into. Default is None which means append new ast.Assign to + 'ast_func'. + insert_before (bool): A bool indicates at before or at after of 'index_ast' where new ast.Assign node to be + inserted into. Only valid when 'index_ast' is not None. Default is True which means + inserting new ast.Assign before 'index_ast'. + + Returns: + An instance of ast.Assign which has been inserted into 'ast_func'. + + Raises: + RuntimeError: If 'index_ast' is not contained in 'ast_func'. + """ + if index_ast is None: + ast_func.body.append(ast_assign) + ast.fix_missing_locations(ast_func) + return ast_assign + for index in range(0, len(ast_func.body)): + if id(ast_func.body[index]) == id(index_ast): + if insert_before: + ast_func.body.insert(index, ast_assign) + else: + ast_func.body.insert(index + 1, ast_assign) + ast.fix_missing_locations(ast_func) + return ast_assign + raise RuntimeError("index_ast is not contained in ast_func") + + @staticmethod + def append_global_vars_expr_to_init(init_func: ast.FunctionDef, targets: [ScopedValue], + field: str) -> ast.AST: + """ + Append an ast.Assign to an ast.FunctionDef which is function named "__init__" in network. Value of new + ast.Assign is an ast.Call represents get an object from global_vars dict. + + While user inserting a custom op, the instance of new custom op is saved in a dict named global_vars. Rewrite + need to get the custom op instance from global_vars in new "__init__" function of network: + self.var1 = global_vars.get("var1") + + Args: + init_func (ast.FunctionDef): An instance of ast.FunctionDef which is "__init__" function of network. + targets ([ScopedValue]): Targets of ast.Assign. + field (str): A string represents name of new custom op field. + + Returns: + An instance of ast.Assign which has been appended to 'init_func'. + """ + return AstModifier.insert_assign_to_function(init_func, targets=targets, + args=[ScopedValue.create_variable_value(field)], + expr=ScopedValue(ValueType.NamingValue, "global_vars", "get")) + + @staticmethod + def create_call_assign(targets: [ScopedValue], expr: ScopedValue, args: [ScopedValue], + kwargs: {str, ScopedValue}) -> ast.Assign: + """ + Create an instance of ast.Assign whose value must ba a ast.Call. + + Args: + targets ([ScopedValue]): Targets of ast.Assign. + expr (ScopedValue): Func of ast.Call which is value of new ast.Assign. + args ([ScopedValue]): Args of ast.Call which is value of new ast.Assign. + kwargs ({str, ScopedValue}): Kwargs of ast.Call which is value of new ast.Assign. + + Returns: + An instance of ast.Assign. + + Raises: + RuntimeError: If 'targets' is None. + RuntimeError: If value_type of element of 'targets' is not ValueType.NamingValue. + + RuntimeError: If length of 'targets' is not 1. Multi-targets will be support in the future. + """ + if targets is None or len(targets) != 1: + raise RuntimeError("Only support one target in insert_cell_to_init now") + if targets[0].type != ValueType.NamingValue: + raise RuntimeError("Target must be a right-value, got: ", targets[0]) + if targets[0].scope: + ast_target = ast.Attribute(ast.Name(targets[0].scope, ast.Load()), targets[0].value, ast.Store()) + else: + ast_target = ast.Name(targets[0].value, ast.Store()) + call = AstModifier.create_call(expr, args, kwargs) + result = ast.Assign(targets=[ast_target], value=call) + ast.fix_missing_locations(result) + return result + + @staticmethod + def create_call(expr: ScopedValue, args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None) -> ast.Call: + """ + Create an instance of ast.Call. + + Args: + expr (ScopedValue): Func of ast.Call. + args ([ScopedValue]): Args of ast.Call. + kwargs ({str, ScopedValue}): Kwargs of ast.Call. + + Returns: + An instance of ast.Call. + + Raises: + RuntimeError: If value_type of 'expr' is ValueType.CustomObjValue. + RuntimeError: If value_type of 'expr' is not ValueType.NamingValue. + RuntimeError: If value_type of element of 'args' is ValueType.CustomObjValue. + RuntimeError: If value_type of value of 'kwargs' is ValueType.CustomObjValue. + TypeError: If expr is not an instance of ScopedValue. + RuntimeError: If element of 'args' is not an instance of ScopedValue. + RuntimeError: If value of 'kwargs' is not an instance of ScopedValue. + """ + if not isinstance(expr, ScopedValue): + raise TypeError("expr should be ScopedValue, got: ", type(expr)) + if expr.type == ValueType.CustomObjValue: + raise RuntimeError("Please handle custom-object first") + if expr.type != ValueType.NamingValue: + raise RuntimeError("Expr must not be a constant, because constant can not been called: ", expr.type) + if expr.scope: + ast_func = ast.Attribute(ast.Name(expr.scope, ast.Load()), expr.value, ast.Store()) + else: + ast_func = ast.Name(expr.value, ast.Store()) + + ast_args = [] + if args is not None: + for arg in args: + if not isinstance(arg, ScopedValue): + raise TypeError("arg should be ScopedValue, got: ", type(arg)) + if arg.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue): + if arg.scope: + raise RuntimeError("arg.scope should be empty") + ast_args.append(ast.Constant(value=arg.value, kind=None)) + elif arg.type == ValueType.NamingValue: + if arg.scope: + ast_args.append(ast.Attribute(ast.Name(arg.scope, ast.Load()), arg.value, ast.Store())) + else: + ast_args.append(ast.Name(arg.value, ast.Store())) + else: + raise RuntimeError("Please handle custom-object first") + keywords = [] + if kwargs is not None: + for arg, value in kwargs.items(): + if not isinstance(value, ScopedValue): + raise TypeError("value should be ScopedValue, got: ", type(value)) + if value.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue): + if value.scope: + raise RuntimeError("value.scope should be empty") + keywords.append(ast.keyword(arg=arg, value=ast.Constant(value=value.value, kind=None))) + elif value.type == ValueType.NamingValue: + if value.scope: + keywords.append(ast.keyword(arg=arg, + value=ast.Attribute(ast.Name(value.scope, ast.Load()), value.value, + ast.Store()))) + else: + keywords.append(ast.keyword(arg=arg, value=ast.Name(value.value, ast.Store()))) + else: + raise RuntimeError("Please handle custom-object first") + result = ast.Call(func=ast_func, args=ast_args, keywords=keywords) + ast.fix_missing_locations(result) + return result + + @staticmethod + def update_arg_value(src_argument: ScopedValue, dst_ast: ast.AST): + """ + Update 'arg_value' by 'input_argument' + + Args: + src_argument (ScopedValue): An instance of ScopedValue represents new argument. + dst_ast (ast.AST): Targets of ast.Assign. + + Raises: + TypeError: Input src_argument is not a ScopedValue + RuntimeError: If 'dst_ast' is an instance of ast.Constant but type of 'src_argument' is not + ValueType.IntValue, ValueType.FloatValue or ValueType.StringValue. + RuntimeError: If 'dst_ast' is an instance of ast.Name or ast.Attribute but type of 'src_argument' is not + ValueType.NamingValue. + RuntimeError: When 'dst_ast' is an instance of ast.Name, scope of 'src_argument' is not empty. + RuntimeError: When 'dst_ast' is an instance of ast.Attribute, value of 'dst_ast' is not an instance of + ast.Name. + RuntimeError: If 'dst_ast' is an instance of ast.Tuple but type of 'src_argument' is not + ValueType.TupleValue. + RuntimeError: If 'dst_ast' is an instance of ast.Constant, ast.Name, ast.Attribute or ast.Tuple. + RuntimeError: When 'dst_ast' is an instance of ast.Tuple, length of elts of 'dst_ast' is not equal to length + of value of 'src_argument'. + """ + if not isinstance(src_argument, ScopedValue): + raise TypeError("src_argument should be ScopedValue, got: ", type(src_argument)) + if isinstance(dst_ast, ast.Constant): + if src_argument.type not in [ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue]: + raise RuntimeError("src_argument should be a IntValue, FloatValue or StringValue, got:", + str(src_argument.type)) + dst_ast.value = src_argument.value + return + if isinstance(dst_ast, ast.Name): + if src_argument.type != ValueType.NamingValue: + raise RuntimeError("src_argument.type should equal to ValueType.NamingValue") + if src_argument.scope: + raise RuntimeError("src_argument.scope should be empty") + dst_ast.id = src_argument.value + return + if isinstance(dst_ast, ast.Attribute): + if src_argument.type != ValueType.NamingValue: + raise RuntimeError("src_argument.type should equal to ValueType.NamingValue") + attr_value = dst_ast.value + if not isinstance(attr_value, ast.Name): + raise RuntimeError("Only support ast.Name as value of attribute ", type(attr_value)) + attr_value.id = src_argument.scope + dst_ast.attr = src_argument.value + return + if isinstance(dst_ast, ast.Tuple): + if src_argument.type != ValueType.TupleValue: + raise RuntimeError("src_argument.type should equal to ValueType.TupleValue") + if len(src_argument.value) != len(dst_ast.elts): + raise RuntimeError("src_argument.value and elts in ast should have same length") + for elt_index, elt in enumerate(dst_ast.elts): + AstModifier.update_arg_value(src_argument.value[elt_index], elt) + return + raise RuntimeError("keyword value type is not supported", type(dst_ast)) diff --git a/mindspore/python/mindspore/rewrite/ast_transformers/__init__.py b/mindspore/python/mindspore/rewrite/ast_transformers/__init__.py new file mode 100644 index 00000000000..59ce74f2ab2 --- /dev/null +++ b/mindspore/python/mindspore/rewrite/ast_transformers/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Transformers for optimizing ast.""" +from .flatten_recursive_stmt import FlattenRecursiveStmt diff --git a/mindspore/python/mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py b/mindspore/python/mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py new file mode 100644 index 00000000000..87c65cadb20 --- /dev/null +++ b/mindspore/python/mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py @@ -0,0 +1,154 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Ast optimizer for flatten recursive call.""" +from typing import Any, Tuple +import ast +from ast import FunctionDef +from mindspore import log as logger + + +class FlattenRecursiveStmt(ast.NodeTransformer): + """Ast optimizer for flatten recursive call.""" + + def __init__(self): + """ + Constructor of FlattenRecursiveStmt. + + Returns: + An instance of ast optimizer for flatten recursive call. + """ + self._flatten_table: dict = { + ast.Return: ["value"], + ast.Call: ["args"], + ast.BinOp: ["left", "right"], + ast.BoolOp: ["values"], + ast.unaryop: ["operand"], + } + + @staticmethod + def _generate_target_name(node: ast.AST, target_names): + """Generate unique target name.""" + if isinstance(node, ast.Call): + func = node.func + if isinstance(func, ast.Name): + target_name = func.id + elif isinstance(func, ast.Attribute): + target_name = func.attr + else: + logger.warning("unhandled type of func of ast.Call while generating new target name: %s ", type(func)) + target_name = "function" + elif isinstance(node, ast.Return): + target_name = "return_value" + elif isinstance(node, (ast.BinOp, ast.boolop, ast.UnaryOp)): + target_name = type(node.op).__name__ + else: + logger.warning("unhandled type of node while generating new target name: %s ", type(node)) + target_name = type(node).__name__ + suffix = 0 + result = target_name + while result in target_names: + suffix += 1 + result = f"{target_name}_{suffix}" + target_names.append(result) + return result + + @staticmethod + def _fill_in_original_target_names(target_names, node): + """Fill in original target names before getting unique names.""" + for function_index in range(len(node.body)): + child = node.body[function_index] + if not isinstance(child, ast.Assign): + continue + targets = child.targets + for target in targets: + if not isinstance(target, ast.Name): + raise RuntimeError("currently only support ast.Name targets") + target_name = target.id + if target_name not in target_names: + target_names.append(target_name) + + @staticmethod + def _create_new_assign_node(node: ast.AST, target_names) -> Tuple[str, ast.AST]: + """Create new assign node to be inserted into ast.FunctionDef.""" + if isinstance(node, (ast.Name, ast.Constant, ast.Num, ast.Str, ast.NameConstant, ast.Bytes, ast.Ellipsis)): + return "", node + new_target_name = FlattenRecursiveStmt._generate_target_name(node, target_names) + return new_target_name, ast.Assign(targets=[ast.Name(id=new_target_name, ctx=ast.Store())], value=node) + + def _flatten_statement(self, node: ast.AST, target_names) -> [ast.AST]: + """Flatten recursive statement according to different node type.""" + flatten_config = self._flatten_table.get(type(node)) + if flatten_config is None: + return [] + results = [] + for todo_name in flatten_config: + todos = getattr(node, todo_name) + if isinstance(todos, list): + new_list = [] + for todo in todos: + new_target_name, new_node = FlattenRecursiveStmt._create_new_assign_node(todo, target_names) + if id(new_node) == id(todo): + new_list.append(todo) + else: + new_list.append(ast.Name(id=new_target_name, ctx=ast.Load())) + results.append(new_node) + setattr(node, todo_name, new_list) + elif isinstance(todos, dict): + new_dict = [] + for key, value in todos: + new_target_name, new_node = FlattenRecursiveStmt._create_new_assign_node(value, target_names) + if id(new_node) == id(value): + new_dict[key] = value + else: + new_dict[key] = ast.Name(id=new_target_name, ctx=ast.Load()) + results.append(new_node) + setattr(node, todo_name, new_dict) + else: + new_target_name, new_node = FlattenRecursiveStmt._create_new_assign_node(todos, target_names) + if id(new_node) != id(todos): + setattr(node, todo_name, ast.Name(id=new_target_name, ctx=ast.Load())) + results.append(new_node) + return results + + def visit_FunctionDef(self, node: FunctionDef) -> Any: + """Traverse construct node and flatten recursive nodes.""" + if node.name != "construct": + return node + + target_names = [] + self._fill_in_original_target_names(target_names, node) + index = len(node.body) - 1 + while index >= 0: + child = node.body[index] + if isinstance(child, ast.Assign): + stmt = child.value + elif isinstance(child, ast.Expr): + stmt = child.value + else: + stmt = child + results = self._flatten_statement(stmt, target_names) + if results: + results.reverse() + for result in results: + node.body.insert(index, result) + index += 1 + index -= 1 + return node + + def transform(self, ast_root): + """Interface of FlattenRecursiveStmt.""" + ast_root = self.visit(ast_root) + ast_root = ast.fix_missing_locations(ast_root) + return ast_root diff --git a/mindspore/python/mindspore/rewrite/namer.py b/mindspore/python/mindspore/rewrite/namer.py new file mode 100644 index 00000000000..1bc7e453a56 --- /dev/null +++ b/mindspore/python/mindspore/rewrite/namer.py @@ -0,0 +1,159 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Unique name producer for target, name of node.""" +from typing import Union + +from .node import Node +from .api.node_type import NodeType + + +class Namer: + """ + Used for unique identity in a class-scope. current used for target of construct-function. + Namer records times of name been used, and add prefix to origin name for unique name. For example, when a Namer + record "name1" has been used 10 times, when a new request require a unique name base on 'name1', namer will respond + "name1_10" as unique name. + """ + + def __init__(self): + """Constructor of Namer.""" + self._names: {str: int} = {} + + @staticmethod + def _real_name(name: str) -> str: + """ + Find real name. For example, "name1" is the real name of "name1_10", "name1" is the real name of "name1_10_3". + If not find real name before find unique name, unique name may be not unique. For example: + + 1. "name1" has been used 10 times, which means "name1", "name1_2", "name1_3" ... "name1_10" has been used; + 2. new request require a unique name base on 'name1_5' + 3. If namer not find real name of "name1_5", namer will find that "name1_5" is never used and respond + "name1_5" as unique name which is used before, actually. + + Args: + name (str): Origin name which may have digit prefix. + + Returns: + A string represents real-name. + """ + pos = name.rfind("_") + if pos == -1: + return name + digit = True + for i in range(pos + 1, len(name)): + if not name[i].isdigit(): + digit = False + break + if digit: + return Namer._real_name(name[:pos]) + return name + + def get_name(self, origin_name: str) -> str: + """ + Get unique name from 'origin_name'. + + Args: + origin_name (str): Origin name which may be duplicated. + + Returns: + A string represents unique-name. + """ + origin_name = Namer._real_name(origin_name) + number = self._names.get(origin_name) + if number is None: + self._names[origin_name] = 1 + return origin_name + self._names[origin_name] = number + 1 + return f"{origin_name}_{number}" + + def add_name(self, name: str): + """ + Add a name to Namer which should be unique. + + Args: + name (str): A name should be unique in current namer. + + Raises: + RuntimeError: If name is not unique in current namer. + """ + real_name = Namer._real_name(name) + number = self._names.get(real_name) + if number is not None: + raise RuntimeError("name duplicated: ", name) + self._names[name] = 1 + + +class TargetNamer(Namer): + """ + Used for unique-ing targets of node. + """ + def get_real_arg(self, origin_arg: str) -> str: + """ + Get real argument from 'origin_arg' because target of node which produces 'origin_arg' may be change for unique. + + Args: + origin_arg (str): Origin argument string which may be undefined cause to target unique-lize. + + Returns: + A string represents real argument name. + """ + num = self._names.get(origin_arg) + if num is None or num == 1: + return origin_arg + return f"{origin_arg}_{num - 1}" + + +class NodeNamer(Namer): + """ + Used for unique-ing node-name which is also used as field of init-function and key of global_vars + """ + + def get_name(self, node_or_name: Union[Node, str]) -> str: + """ + Override get_name in Namer class. + Get unique node_name from 'origin_name' or an instance of node. + + Args: + node_or_name (Union[Node, str]): A string represents candidate node_name or an instance of node who require + A unique node_name. + + Returns: + A string represents unique node_name. + """ + if isinstance(node_or_name, Node): + origin_name = node_or_name.get_name() + if origin_name is None or not origin_name: + if node_or_name.get_node_type() == NodeType.CallCell: + if not isinstance(node_or_name, Node): + raise TypeError("node_or_name should be Node, got: ", type(node_or_name)) + targets = node_or_name.get_targets() + # return node and head node will not call this method + if not targets: + raise RuntimeError("node should has at lease one target except return-node and head-node: ", + node_or_name) + origin_name = str(targets[0].value) + elif node_or_name.get_node_type() == NodeType.Python: + origin_name = node_or_name.get_instance().__name__ + elif node_or_name.get_node_type() == NodeType.Input: + origin_name = "parameter" + else: + raise RuntimeError("Node type unsupported:", node_or_name.get_node_type()) + elif isinstance(node_or_name, str): + if not node_or_name: + raise RuntimeError("input node_name is empty.") + origin_name = node_or_name + else: + raise RuntimeError("unexpected type of node_or_name: ", type(node_or_name)) + return super(NodeNamer, self).get_name(origin_name) diff --git a/mindspore/python/mindspore/rewrite/node.py b/mindspore/python/mindspore/rewrite/node.py new file mode 100644 index 00000000000..be9e86524f9 --- /dev/null +++ b/mindspore/python/mindspore/rewrite/node.py @@ -0,0 +1,1092 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Node class define of Rewrite. See detail in Node class docstring.""" +from typing import Optional, Union +import ast +import inspect + +from mindspore.nn import Cell +from mindspore import log as logger +from .ast_modifier import AstModifier +from .api.scoped_value import ScopedValue, ValueType +from .api.node_type import NodeType + +PASS_THROUGH_METHOD = ScopedValue.create_naming_value("PassThrough") + + +class Node: + """ + Node is a data structure represents a source code line in network. For the most part, Node represents an operator + invoking in forward which could be an instance of Cell, an instance of Primitive or a callable method. Fields of + Node has different meaning in different type of node: + + - CallCell: a call-cell node represents an assign statement whose value is a calling to cell in mindspore. `targets` + is corresponding to targets of ast.Assign which means return values of this cell-op. `args` and `kwargs` are + corresponding to args and keywords of ast.Call which mean arguments to invoke cell-op's forward method. `func` is + corresponding to func of call expression which means symbol of the cell-op. + - CallPrimitive: a call-primitive node represents an ast.Assign whose value is a calling to operator in mindspore. + `targets`, `args`, `kwargs` and `func` are as previous. + - CallMethod: a call-method node represents an ast.Assign whose value is a calling to python-method such as `len`. + `targets` is corresponding to targets of ast.Assign which means return values of this method. `func` represents + the string name of method. `args` and `kwargs` are corresponding to args and keywords to invoke the method. When + value of ast.Assign is an ast.Name or ast.Attribute, it means a simplest assign which would also be mapped to + CallMethod node whose `func` is "PassThrough". + - GetAttr: retrieves a parameter from the SymbolTree hierarchy. `func` represents which parameter in SymbolTree + hierarchy. `targets` is corresponding to targets of ast.Assign which means what symbol to accept the result of + get-attr. `args` and `kwargs` are don't-care. + - Python: a python node holds an ast-node which is not parsed. a python node means some python statement is not + supported by Rewrite or ignored by Rewrite. `targets`, `args`, `kwargs` and `func` are don't-care. + - Input: an input node represents an input of current network which also a parameter of forward method of Cell. + `targets` is corresponding to arg-name of parameter of forward function. `args` means default-value of parameter + of forward function. `kwargs` and `func` are don't-care. + - Output: an output node represents the output of current network which is corresponding to return statement of + forward method of Cell. `args` represents return values. `func` are always be "return". `targets` and `kwargs` are + don't-care. + - Tree: a tree node represents a sub-network call in current network. A sub-network is also a Cell in mindspore, so + `targets`, `args`, `kwargs` and `func` are same as a call-cell node. `symbol_tree` is a handler of a SymbolTree + instance. + """ + + def __init__(self, node_type: NodeType, ast_node: Optional[ast.AST], targets: [ScopedValue], + func: Optional[ScopedValue], args: [ScopedValue], kwargs: {str: ScopedValue}, name: str, instance): + """ + Constructor of Node. Rewrite recommend invoking class method of Node to instantiate an instance of Node such + as `create_call_cell`, `create_call_method`, `create_python_node`, `create_input_node` and `create_output_node`, + etc. rather than invoking constructor of Node directly. + + Args: + node_type (NodeType): A NodeType as type of Node. + ast_node (Optional[ast.AST]): An instance of ast.AST represents corresponding node in ast. `ast_node` should + not be None except when node type is Unknown. + targets ([ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class. + func (Optional[ScopedValue]): An instance of ScopedValue. See detail in docstring of Node class. + args ([ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class. + kwargs ({str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class. + name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree. + Name of node also used as field name in network class. + instance: Object in network corresponding to this node. + """ + self._node_type: NodeType = node_type + self._ast_node: Optional[ast.AST] = ast_node + self._attribute: {str, object} = Node._get_cell_or_prim_op_attribute(instance) + self._instance = instance + self._name = name + self._func: Optional[ScopedValue] = func + self._targets: [ScopedValue] = targets + self._args_num = len(args) + self._kwargs_num = len(kwargs) + self._normalized_args_keys = [] # for saving args' order + self._normalized_args = self._get_normalized_args(args, kwargs) + # edge of node + self._inputs: [Node] = [] + # position in graph nodes list + # it will affect code-order of python code + self._prev: Optional[Node] = None + self._next: Optional[Node] = None + # A handler of SymbolTree current node belonging to + self._belong_tree = None + + @classmethod + def create_call_cell(cls, cell: Cell, ast_node: Optional[ast.AST], targets: [Union[ScopedValue, str]], + func: Union[ScopedValue, str], args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None, + name: str = ""): + """ + Class method of Node. Instantiate an instance of node whose type is CallCell. A CallCell node represents an + invoking to cell-op. + + Args: + cell (Cell): An instance of Cell corresponding to this node. + ast_node (Optional[ast.AST]): An instance of ast.AST represents corresponding node in ast. + targets ([ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class. + func (Optional[ScopedValue]): An instance of ScopedValue. See detail in docstring of Node class. + args ([ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class. + kwargs ({str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class. + name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree. + Name of node also used as field name in network class. + """ + if args is None: + args = [] + if kwargs is None: + kwargs = {} + if isinstance(func, str): + func = ScopedValue.create_naming_value(func) + non_custom_args = Node._handle_custom_obj_in_args(args) + non_custom_kwargs = Node._handle_custom_obj_in_kwargs(kwargs) + new_targets = Node._handle_targets(targets) + if ast_node is None: + ast_node = AstModifier.create_call_assign(new_targets, func, non_custom_args, non_custom_kwargs) + return cls(NodeType.CallCell, ast_node, new_targets, func, args, kwargs, name, cell) + + @classmethod + def create_call_method(cls, ast_node: Optional[ast.AST], targets: [Union[ScopedValue, str]], + func: Union[ScopedValue, str], args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None, + name: str = ""): + """ + Class method of Node. Instantiate an instance of node whose type is CallCell. A CallCell node represents an + invoking to cell-op. + + Args: + ast_node (Optional[ast.AST]): An instance of ast.AST represents corresponding node in ast. `ast_node` should + not be None currently. + targets ([ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class. + func (Optional[ScopedValue]): An instance of ScopedValue. See detail in docstring of Node class. + args ([ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class. + kwargs ({str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class. + name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree. + Name of node also used as field name in network class. + """ + if args is None: + args = [] + if kwargs is None: + kwargs = {} + if isinstance(func, str): + func = ScopedValue.create_naming_value(func) + new_targets = Node._handle_targets(targets) + if ast_node is None: + raise RuntimeError("Input ast_node is None") + return cls(NodeType.CallMethod, ast_node, new_targets, func, args, kwargs, name, None) + + @classmethod + def create_call_pass_through_method(cls, ast_node: Optional[ast.AST], targets: [Union[ScopedValue, str]], + args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None, name: str = ""): + return Node.create_call_method(ast_node, targets, PASS_THROUGH_METHOD, args, kwargs, name) + + @classmethod + def create_python_node(cls, ast_node: ast.AST, name: str = "", instance=None): + """ + Class method of Node. Instantiate an instance of node whose type is Python. A Python node represents some python + statement is not supported by Rewrite or ignored by Rewrite. + + Args: + ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast. + name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree. + Name of node also used as field name in network class. + instance: An object corresponding to this node in network. + """ + return cls(NodeType.Python, ast_node, None, None, [], {}, name, instance) + + @classmethod + def create_input_node(cls, ast_node: ast.AST, arg_name: str, default: Optional[ScopedValue] = None, name: str = ""): + """ + Class method of Node. Instantiate an instance of node whose type is Input. An Input node represents input of + SymbolTree which is corresponding to parameters of forward function. + + Args: + ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast. + arg_name (str): A string represents name of parameter. + default (Optional[ScopedValue]): An instance of ScopedValue represents default value of parameter. + name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree. + Name of node also used as field name in network class. + """ + target = ScopedValue.create_naming_value(arg_name) + if default is None: + args = [] + else: + args = [default] + return cls(NodeType.Input, ast_node, [target], None, args, {}, name, None) + + @classmethod + def create_output_node(cls, ast_node: ast.AST, return_values: [str], name: str = "return"): + """ + Class method of Node. Instantiate an instance of node whose type is Output. An Output node represents output of + SymbolTree which is corresponding to return statement of forward function. + + Args: + ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast. + return_values ([str]): A list of string represents name of return values. + name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree. + Name of node also used as field name in network class. + """ + real_return_values = ScopedValue.create_name_values(return_values) + return cls(NodeType.Output, ast_node, None, ScopedValue.create_naming_value("return"), real_return_values, {}, + name, + None) + + @staticmethod + def _get_construct_arg_names(parameters): + """ + Static method of Node. Get parameters' names of the construct function. + + Args: + parameters (MappingProxyType): An ordered mapping of parameters' names to the corresponding Parameter + objects. + + Raises: + RuntimeError: Invalid parameter kind. + + Returns: + - arg_names, Parameters' names, contain parameters of types in [POSITIONAL_ONLY, POSITIONAL_OR_KEYWORD]. + - var_positional_name, Name of VAR_POSITIONAL parameters. + - var_keyword_name, Name of VAR_KEYWORD parameters. + """ + position_only_names: [str] = [] + positional_or_keyword_names: [str] = [] + var_positional_name = None + keyword_only_names: [str] = [] + var_keyword_name = None + for name, para in parameters.items(): + if para.kind == inspect.Parameter.POSITIONAL_ONLY: # parameters which appear before a '/' + position_only_names.append(name) + elif para.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD: # parameters which appear before '*' or '*args' + positional_or_keyword_names.append(name) + elif para.kind == inspect.Parameter.VAR_POSITIONAL: # corresponds to a '*args' + var_positional_name = name + elif para.kind == inspect.Parameter.KEYWORD_ONLY: # parameters which appear after '*' and before '**' + keyword_only_names.append(name) + elif para.kind == inspect.Parameter.VAR_KEYWORD: # corresponds to a '**kwargs' + var_keyword_name = name + else: + raise RuntimeError("invalid para kind", para.kind) + if "self" in position_only_names: + position_only_names.remove("self") + if "self" in positional_or_keyword_names: + positional_or_keyword_names.remove("self") + names = (position_only_names, positional_or_keyword_names, var_positional_name, keyword_only_names, + var_keyword_name) + return names + + @staticmethod + def _map_args_names(names: tuple, args: [ScopedValue], kwargs: {str: ScopedValue}, + normalized_args_keys: [str], normalized_args: {str: ScopedValue}): + """ + Fill in normalized_args according to the order of parameters of construct func. + + Args: + names (tuple): Parameters' name got from construct func. + args ([ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class. + kwargs ({str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class. + normalized_args ({str: ScopedValue}): The normalized args to be filled. + + Raises: + RuntimeError: Input args are invalid. + RuntimeError: Arg name already exist in kwargs. + RuntimeError: Input kwargs invalid. + """ + position_only_names, positional_or_keyword_names, var_positional_name, keyword_only_names, var_keyword_name = \ + names + for arg_index, arg in enumerate(args): + if arg_index < len(position_only_names): + arg_key = position_only_names[arg_index] + elif arg_index < len(position_only_names) + len(positional_or_keyword_names): + arg_key = positional_or_keyword_names[arg_index - len(position_only_names)] + elif var_positional_name: + arg_key = "{}_{}".format(var_positional_name, arg_index) + else: + raise RuntimeError("Input args are invalid.") + + if arg_key in kwargs.keys(): + raise RuntimeError("Arg name already exist in kwargs.") + normalized_args[arg_key] = arg + normalized_args_keys.append(arg_key) + + # add kwargs according to parameters' order + parameters_order: [str] = [] + parameters_order.extend(position_only_names) + parameters_order.extend(positional_or_keyword_names) + parameters_order.append(var_keyword_name) + parameters_order.extend(keyword_only_names) + parameters_order.append(var_keyword_name) + + sorted_kwargs = [] + var_keyword_count = len(parameters_order) + for arg_key, value in kwargs.items(): + if arg_key not in parameters_order and not var_keyword_name: + raise RuntimeError("Input kwargs invalid.") + if arg_key in parameters_order: + sorted_kwargs.append([arg_key, value, parameters_order.index(arg_key)]) + else: + sorted_kwargs.append([arg_key, value, var_keyword_count]) + var_keyword_count += 1 + + sorted_kwargs.sort(key=lambda x: x[2]) + for sorted_kwarg in sorted_kwargs: + normalized_args[sorted_kwarg[0]] = sorted_kwarg[1] + normalized_args_keys.append(sorted_kwarg[0]) + + def _get_normalized_args(self, args: [ScopedValue], kwargs: {str: ScopedValue}) -> dict: + """ + Merge args and kwargs to normalized args. + The keys of args are obtained from the construct function of type(self._instance). + + Args: + args ([ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class. + kwargs ({str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class. + + Raises: + RuntimeError: Input args are invalid. + RuntimeError: Arg name already exist in kwargs. + + Returns: + The normalized args. + """ + if not args: + args = [] + if not kwargs: + kwargs = {} + normalized_args: dict = dict() + if self._instance and hasattr(type(self._instance), "construct"): + parameters = inspect.signature(type(self._instance).construct).parameters + names = Node._get_construct_arg_names(parameters) + Node._map_args_names(names, args, kwargs, self._normalized_args_keys, normalized_args) + else: + logger.debug("fail to get arg name from op, using arg_xx for args' name") + arg_temp_name, suffix = "arg", 0 + for arg in args: + arg_key = "{}_{}".format(arg_temp_name, suffix) + while arg_key in kwargs.keys() or arg_key in normalized_args.keys(): + suffix += 1 + arg_key = "{}_{}".format(arg_temp_name, suffix) + normalized_args[arg_key] = arg + self._normalized_args_keys.append(arg_key) + for arg_key, value in kwargs.items(): + normalized_args[arg_key] = value + self._normalized_args_keys.append(arg_key) + return normalized_args + + @staticmethod + def _handle_custom_obj_in_args(args: [ScopedValue]) -> [ScopedValue]: + """ + Convert CustomObjValue type argument to NamingValue type argument. + + Args: + args ([ScopedValue]): A list of instance of ScopedValue to be converted. + + Returns: + A list of instance of ScopedValue which have been converted. + """ + result = [] + for arg in args: + if not isinstance(arg, ScopedValue): + raise TypeError("arg should be ScopedValue, got: ", type(arg)) + if arg.type == ValueType.CustomObjValue: + logger.warning("custom-object exist in args, should be replace before compile") + result.append(ScopedValue.create_naming_value("custom-object", "self")) + else: + result.append(arg) + return result + + @staticmethod + def _handle_custom_obj_in_kwargs(kwargs: {str: ScopedValue}) -> {str: ScopedValue}: + """ + Convert CustomObjValue type argument to NamingValue type argument. + + Args: + kwargs ({str: ScopedValue}): A str to instance of ScopedValue dict whose value to be converted. + + Returns: + A str to instance of ScopedValue dict whose value has be converted. + """ + result: {str, ScopedValue} = {} + for arg, value in kwargs.items(): + if not isinstance(value, ScopedValue): + raise TypeError("value should be ScopedValue, got: ", type(value)) + if value.type == ValueType.CustomObjValue: + result[arg] = ScopedValue.create_naming_value("custom-object", "self") + else: + result[arg] = value + return result + + @staticmethod + def _handle_targets(targets: [Union[ScopedValue, str]]) -> [ScopedValue]: + """ + Normalize targets to be a list of ScopedValue. If target is a str, it will be converted to NamingValue type + ScopedValue. + + Args: + targets (Union[ScopedValue, str]]): A list whose element could be a ScopedValue or a str to be normalized. + + Returns: + A list of instance of ScopedValue which have been converted. + """ + if not isinstance(targets, list): + raise TypeError("targets should be list, got: ", type(targets)) + results = [] + for target in targets: + if isinstance(target, str): + results.append(ScopedValue.create_naming_value(target)) + elif isinstance(target, ScopedValue): + results.append(target) + else: + raise RuntimeError("Invalid symbol type: ", target) + return results + + def get_prev(self) -> 'Node': + """ + Get previous node of current node in source code order. + + Returns: + An instance of Node as previous node. + """ + return self._prev + + def get_next(self) -> 'Node': + """ + Get next node of current node in source code order. + + Returns: + An instance of Node as next node. + """ + return self._next + + def has_same_ast(self, node: Union['Node', ast.AST]) -> bool: + """ + Check if other node holds same ast node with self. + + Args: + node (Union[Node, ast.AST]): An instance of ast.AST or an instance of node to be compared. + + Returns: + A bool. + """ + if isinstance(node, Node): + return self.has_same_ast(node._ast_node) + if isinstance(node, ast.AST): + return id(self._ast_node) == id(node) + return False + + def get_ast(self) -> Optional[ast.AST]: + """ + Getter of _ast_node. + + Returns: + An instance of ast.AST if self._ast_node if not None else None. + """ + return self._ast_node + + def set_ast(self, ast_node: ast.AST): + """ + Setter of _ast_node. + + Args: + ast_node (ast.AST): An instance of ast.AST as new value for _ast_node. + """ + if not isinstance(ast_node, ast.AST): + raise TypeError("ast_node should be ast.AST, got: ", type(ast_node)) + self._ast_node = ast_node + + def get_belong_symbol_tree(self): + return self._belong_tree + + def set_belong_symbol_tree(self, symbol_tree): + self._belong_tree = symbol_tree + + def _sync_assign_func_to_ast(self): + """Sync func of ast.Call of ast.Assign from self._name when NodeType is CallCell.""" + if self._ast_node is None: + return + assign_ast = self._ast_node + if not isinstance(assign_ast, ast.Assign): + raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast)) + call_ast = assign_ast.value + if not isinstance(call_ast, ast.Call): + raise TypeError("call_ast should be ast.Call, got: ", type(call_ast)) + func_ast = call_ast.func + if not self._func.value: + if isinstance(func_ast, ast.Name): + func_ast.id = self._func.value + else: + call_ast.func = ast.Name(self._func.value, ast.Store()) + else: + if isinstance(func_ast, ast.Attribute): + func_value = func_ast.value + if not isinstance(func_value, ast.Name): + raise RuntimeError("Only support ast.Name as value of attribute ", type(func_ast.value)) + func_value.id = self._func.scope + func_ast.attr = self._func.value + else: + call_ast.func = ast.Attribute(ast.Name(self._func.scope, ast.Load()), self._func.value, ast.Store()) + ast.fix_missing_locations(assign_ast) + + def _sync_assign_targets_to_ast(self): + """Sync targets of ast.Assign from self._targets when NodeType is CallCell or CallMethod.""" + if self._ast_node is None: + return + assign_ast = self._ast_node + if not isinstance(assign_ast, ast.Assign): + raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast)) + # update targets + targets_ast = assign_ast.targets + if len(self._targets) != len(targets_ast): + raise RuntimeError("self._targets should have targets_ast same length") + for i in range(0, len(self._targets)): + target = self._targets[i] + target_ast = targets_ast[i] + if not isinstance(target_ast, ast.Name): + raise TypeError("target_ast should be ast.Name, got: ", type(target_ast)) + target_ast.id = target.value + ast.fix_missing_locations(assign_ast) + + def _sync_call_cell_args_to_ast(self): + """Sync args of ast.Cell of ast.Assign from self._normalized_args when NodeType is CallCell.""" + if self._ast_node is None: + return + assign_ast = self._ast_node + if not isinstance(assign_ast, ast.Assign): + raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast)) + assign_value = assign_ast.value + if not isinstance(assign_value, ast.Call): + return + keywords_ast = assign_value.keywords + args_ast = assign_value.args + if len(self._normalized_args_keys) != (len(keywords_ast) + len(args_ast)): + raise RuntimeError("ast keywords plus args len is not equal to self._normalized_args value") + + for arg_index in range(self._args_num): + arg_ast = args_ast[arg_index] + AstModifier.update_arg_value(self._normalized_args.get(self._normalized_args_keys[arg_index]), arg_ast) + + # the order of kwargs may not the same as that in keywords_ast + keyword_map_index = {} + for index, keyword_ast in enumerate(keywords_ast): + keyword_map_index[keyword_ast.arg] = index + for keyword_index in range(self._kwargs_num): + key = self._normalized_args_keys[keyword_index + self._args_num] + AstModifier.update_arg_value(self._normalized_args.get(key), keywords_ast[keyword_map_index[key]].value) + + def _sync_call_method_args_to_ast(self): + """Sync args of ast.Cell of ast.Assign from self._normalized_args when NodeType is CallMethod.""" + if self._ast_node is None: + return + assign_ast = self._ast_node + if not isinstance(assign_ast, ast.Assign): + raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast)) + assign_value = assign_ast.value + if self._func == PASS_THROUGH_METHOD: + if isinstance(assign_value, ast.Name): + if len(self._normalized_args_keys) != 1: + raise RuntimeError("self._normalized_args_keys should have 1 elements") + arg = self._normalized_args.get(self._normalized_args_keys[0]) + if arg.type != ValueType.NamingValue: + raise RuntimeError("arg.type should equal to ValueType.NamingValue") + if arg.scope != "": + raise RuntimeError("arg.scope should be empty") + assign_value.id = arg.value + elif isinstance(assign_value, ast.Attribute): + if len(self._normalized_args_keys) != 1: + raise RuntimeError("self._normalized_args_keys should have 1 elements") + arg = self._normalized_args.get(self._normalized_args_keys[0]) + if arg.type != ValueType.NamingValue: + raise RuntimeError("arg.type should equal to ValueType.NamingValue") + assign_value.attr = arg.value + assign_value_value = assign_value.value + if not isinstance(assign_value_value, ast.Name): + raise RuntimeError("Only support ast.Name as value of attribute ", type(assign_value_value)) + assign_value_value.id = arg.scope + else: + if len(self._normalized_args_keys) != 1: + raise RuntimeError("self._normalized_args_keys should have 1 elements") + arg = self._normalized_args.get(self._normalized_args_keys[0]) + if arg.type not in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue): + raise RuntimeError("arg should be an IntValue, FloatValue or StringValue") + if arg.scope != "": + raise RuntimeError("arg.scope should be empty") + assign_value.value = arg.value + else: + raise RuntimeError("Only support pass_through method as call_method now, ", self._func.value) + + def _sync_return_node_to_ast(self): + """Sync return value of ast.Return from self._normalized_args when NodeType is Output.""" + if self._ast_node is None: + return + return_ast = self._ast_node + if not isinstance(return_ast, ast.Return): + raise TypeError("return_ast should be ast.Return, got: ", type(return_ast)) + # update args + return_value_ast = return_ast.value + if isinstance(return_value_ast, ast.Name): + if len(self._normalized_args_keys) != 1: + raise RuntimeError("self._normalized_args_keys should have 1 elements") + return_value_ast.id = self._normalized_args.get(self._normalized_args_keys[0]).value + elif isinstance(return_value_ast, ast.Tuple): + elements = return_value_ast.elts + if len(self._normalized_args.values()) != len(elements): + raise RuntimeError("self._normalized_args.values() should have elements same length") + for elt_index, elt in enumerate(elements): + if not isinstance(elt, ast.Name): + raise RuntimeError("Only support ast.Name as return value: ", elt) + arg = self._normalized_args.get(self._normalized_args_keys[elt_index]) + if not isinstance(arg, ScopedValue): + raise TypeError("arg should be ScopedValue, got: ", type(arg)) + elt.id = arg.value + else: + raise RuntimeError("Unsupported return value type: ", return_value_ast) + ast.fix_missing_locations(return_ast) + + def isolate(self): + """Link prev node to next node and isolate node from source code order list.""" + origin_prev: Optional[Node] = self._prev + origin_next: Optional[Node] = self._next + if origin_prev is not None: + origin_prev._next = origin_next + if origin_next is not None: + origin_next._prev = origin_prev + self._prev = None + self._next = None + + def insert_before(self, node: 'Node'): + """ + Insert a node before current node in source code list. Note that topological order is not determined here. + + Args: + node (Node): An instance of node to be inserted in. + """ + node.isolate() + origin_prev: Optional[Node] = self._prev + if origin_prev is not None: + origin_prev._next = node + node._prev = origin_prev + node._next = self + self._prev = node + + def insert_after(self, node: 'Node'): + """ + Insert a node after current node in source code list. Note that topological order is not determined here. + + Args: + node (Node): An instance of node to be inserted in. + """ + node.isolate() + origin_next: Optional[Node] = self._next + self._next = node + node._prev = self + node._next = origin_next + if origin_next is not None: + origin_next._prev = node + + def get_inputs(self) -> ['Node']: + """ + Getter of _inputs which represents input nodes of current node in topological order. + + Returns: + A list of instances of Node as input nodes. + """ + return self._inputs + + def set_inputs(self, inputs: ['Node']): + """ + Setter of _inputs which represents input nodes of current node in topological order. + + + Args: + inputs ([Node]): A list of instances of Node as new input nodes. + """ + self._inputs = inputs + + def get_targets(self) -> [ScopedValue]: + """ + Getter of _targets. + + - When node_type of current node is CallCell or CallPrimitive or CallMethod or Tree, `targets` are strings + represents invoke result of the cell-op or primitive-op or function-call which are corresponding to targets of + ast.Assign. + - When node_type of current node is Input, `targets` should have only one element which is a string represents + name of parameter of function. + - When node_type of current node is Python or Output, `targets` are don't-care. + + Returns: + A list of instances of ScopedValue as targets of node. + """ + return self._targets + + def set_targets(self, targets: [ScopedValue]): + """ + Setter of _targets. + + Note: + This interface can only be called before node been inserted into symbol-tree because target will be unique + while insert into symbol-tree, in other word, set_targets is not a user-interface. + + When `_targets` is updated, corresponding ast node would be updated also. + + When node_type of current node is CallCell or CallPrimitive or CallMethod or Tree, `targets` are strings + represents invoke result of the cell-op or primitive-op or function-call which are corresponding to targets + of ast.Assign. + + When node_type of current node is Input, `targets` should have only one element which is a string represents + name of parameter of function. + + When node_type of current node is Python or Output, `targets` are don't-care. + + Args: + targets ([ScopedValue]): A list of instances of ScopedValue as new targets. + """ + self._targets = targets + if self._node_type in (NodeType.CallCell, NodeType.CallMethod): + self._sync_assign_targets_to_ast() + + def get_func(self) -> ScopedValue: + """ + Getter of `_func`. See detail in docstring of Node class for meaning of func. + + Returns: + An instance of ScopedValue. + """ + return self._func + + def set_func(self, func: ScopedValue): + """ + Setter of `_func`. See detail in docstring of Node class for meaning of func. + + Note: + When `_func` is updated, corresponding ast node would be updated also. + + Args: + func (ScopedValue): An instance of ScopedValue as new func. + """ + self._func = func + if self._node_type == NodeType.CallCell: + self._sync_assign_func_to_ast() + + def get_name(self) -> str: + """ + Getter of `_name`. + + Returns: + A str represents name of node. + """ + return self._name + + def set_name(self, name: str): + """ + Setter of `_name`. + + Args: + name (str): A str as new name of node. + """ + self._name = name + + def get_node_type(self) -> NodeType: + """ + Get the node_type of current node. + + Returns: + A NodeType as node_type of node. + """ + return self._node_type + + def get_instance_type(self) -> type: + """ + Get the instance_type of current node. + When node_type of current node is CallCell, instance_type is type of cell-op. + When node_type of current node is CallPrimitive, instance_type is type of primitive-op. + When node_type of current node is Tree, instance_type is type of network-cell. + When node_type of current node is Python, Input, Output or CallMethod, instance_type should be NoneType + + Returns: + A type. + """ + return type(self._instance) + + def get_instance(self): + """ + Get the instance of current node. + When node_type of current node is CallCell, instance is an instance of Cell. + When node_type of current node is CallPrimitive, instance is an instance of primitive. + When node_type of current node is Tree, instance is an instance of network-cell. + When node_type of current node is Python, Input, Output or CallMethod, instance should be None + + Returns: + A object. + """ + return self._instance + + def _sync_arg(self): + """Sync _normalized_args to corresponding ast node when updated.""" + if self._node_type == NodeType.CallCell: + self._sync_call_cell_args_to_ast() + elif self._node_type == NodeType.Output: + self._sync_return_node_to_ast() + elif self._node_type == NodeType.CallMethod: + self._sync_call_method_args_to_ast() + + def set_arg_by_node(self, arg_idx: int, node: 'Node', out_idx: Optional[int] = None): + """ + Set argument by another Node. + Note that when _normalized_args is updated, corresponding ast node would be updated also. + + Args: + arg_idx (int): Indicate which input being modified. + node (Node): Node as new input. Can be a node or name of node. + out_idx (Optional[int]): Indicate which output of 'node' as new argument. Default is None which means use + first output of 'node_to_link' as new input. + + Raises: + RuntimeError: If 'arg_idx' is out of range. + RuntimeError: If 'node' has multi-outputs while 'out_idx' is None or 'out_idx' is not offered. + """ + if not isinstance(node, Node): + raise TypeError("node should be Node, got: ", type(node)) + if arg_idx >= self._args_num or arg_idx < 0: + raise RuntimeError("arg_idx out of range: ", arg_idx) + if out_idx is None: + if len(node._targets) != 1: + raise RuntimeError("node should has one output when out_idx is not provided") + out_idx = 0 + if out_idx >= len(node._targets): + raise RuntimeError("out_idx out of range: ", out_idx) + new_arg = node._targets[out_idx] + self._normalized_args[self._normalized_args_keys[arg_idx]] = new_arg + self._sync_arg() + + def set_arg(self, arg: Union[ScopedValue, str], index: int) -> (ScopedValue, ScopedValue): + """ + Set argument of 'node'. + Note that when _normalized_args is updated, corresponding ast node would be updated also. + + Args: + index (int): Indicate which input being modified. + arg (Union[ScopedValue, str]): New argument to been set. + + Raises: + RuntimeError: If 'index' is out of range. + """ + if index < 0 or index >= self._args_num: + raise RuntimeError("index error", index) + if isinstance(arg, str): + arg = ScopedValue.create_naming_value(arg) + old_arg = self._normalized_args.get(self._normalized_args_keys[index]) + self._normalized_args[self._normalized_args_keys[index]] = arg + self._sync_arg() + return arg, old_arg + + def set_args(self, args: [ScopedValue]): + """ + Set arguments of 'node'. + Note that when _normalized_args is updated, corresponding ast node would be updated also. + + Args: + args ([ScopedValue]): New arguments to been set. + + Raises: + TypeError: Element of new argument is not an instance of ScopedValue. + """ + if len(args) != self._args_num: + raise RuntimeError("Length of args should be equal to _args_num, ", len(args), " vs ", self._args_num) + for arg_index, arg in enumerate(args): + if not isinstance(arg, ScopedValue): + raise TypeError("arg should be ScopedValue, got: ", type(arg)) + self._normalized_args[self._normalized_args_keys[arg_index]] = arg + self._sync_arg() + + def set_kwargs(self, kwargs: {str: ScopedValue}): + """ + Set keywords arguments of 'node'. + Note that when _normalized_args is updated, corresponding ast node would be updated also. + + Args: + kwargs ({str: ScopedValue}): New arguments to been set. + + Raises: + TypeError: Value of new argument is not an instance of ScopedValue. + RuntimeError: Length of new arguments is not equal to length of old arguments. + """ + if len(kwargs) != self._kwargs_num: + raise RuntimeError("Length of kwargs should be equal to _kwargs_num, ", len(kwargs), " vs ", + self._kwargs_num) + for key, arg in kwargs.items(): + if key not in self._normalized_args.keys() or key not in self._normalized_args_keys: + raise RuntimeError("Input key is not exist, ", key) + if not isinstance(arg, ScopedValue): + raise TypeError("arg should be ScopedValue, got: ", type(arg)) + self._normalized_args[key] = arg + self._sync_arg() + + def set_kwarg(self, key: str, arg: ScopedValue): + """ + Set keyword argument of 'node'. + Note that when _normalized_args is updated, corresponding ast node would be updated also. + + Args: + key (str): A str represents key of new argument. + arg (ScopedValue): An instance of ScopedValue represents argument. + + Raises: + RuntimeError: If 'key' is not in original kwargs' keys. + """ + if key not in self._normalized_args_keys[self._args_num:] or key not in self._normalized_args.keys(): + raise RuntimeError("Input key is not exist, ", key) + self._normalized_args[key] = arg + self._sync_arg() + + def get_args(self): + """ + Get the arguments of current node. + When node_type of current node is CallCell, CallPrimitive or Tree, arguments are corresponding to args of + ast.Call which represents arguments to invoke cell-op's forward method or primitive-op's `call()` method. + When node_type of current node is Input, arguments represents default-value of argument of function. + When node_type of current node is Output, arguments represents return values. + When node_type of current node is Python, arguments are don't-care. + + Returns: + A list of instances of ScopedValue. + """ + args = [] + for arg_index in range(self._args_num): + args.append(self._normalized_args.get(self._normalized_args_keys[arg_index])) + return args + + def get_kwargs(self): + """ + Get the keyword arguments of current node. + When node_type of current node is CallCell, CallPrimitive or Tree, keyword arguments are corresponding to kwargs + of ast.Call which represents arguments to invoke cell-op's forward method or primitive-op's `call()` method. + When node_type of current node is Python, Input or Output, keyword arguments are don't-care. + + Returns: + A dict of str to instance of ScopedValue. + """ + kwargs: {str, ScopedValue} = {} + for arg_index in range(self._args_num, self._args_num + self._kwargs_num): + key = self._normalized_args_keys[arg_index] + kwargs[key] = self._normalized_args.get(key) + return kwargs + + def get_normalized_args(self) -> {str: ScopedValue}: + """ + Get the normalized keyword arguments of current node. + Normalized arguments combine arguments and keyword arguments into keyword arguments by using parameter name as + key of arguments. + + Returns: + A dict of str to instance of ScopedValue. + """ + output = {} + for key in self._normalized_args_keys: + output[key] = self._normalized_args.get(key) + return output + + def set_normalized_args(self, args: {str, ScopedValue}): + """ + Set the normalized keyword arguments of current node. + Normalized arguments combine arguments and keyword arguments into keyword arguments by using parameter name as + key of arguments. + + Args: + args ({str, ScopedValue}): A dict of str to instance of ScopedValue represents new normalized_args. + """ + if len(args.values()) != len(self._normalized_args_keys): + raise RuntimeError("Length of args.values() should be equal to length of _normalized_args_keys, ", + len(args.values()), " vs ", len(self._normalized_args_keys)) + for key, arg in args.items(): + self._normalized_args[key] = arg + self._sync_arg() + + def set_attribute(self, key: str, value): + """ + Set attribute of current node. + + Args: + key (str): Key of new attribute. + value (object): Value of new attribute. + """ + self._attribute[key] = value + + def set_attributes(self, attributes): + """ + Set attributes of current node. + + Args: + attributes (dict): A dict represents new attributes. + """ + self._attribute = attributes + + def get_attributes(self): + """ + Get all attributes of current node. + + Returns: + A dict of str to instance of object as attributes. + """ + return self._attribute + + def get_attribute(self, key: str): + """ + Get attribute of current node by key. + + Args: + key (str): A str represents key of attribute you want to get. + + Returns: + A object as attribute. + """ + return self._attribute.get(key) + + @staticmethod + def _get_cell_or_prim_op_attribute(obj) -> dict: + """ + Find attributes of cell-op or primitive-op. + + Args: + obj: A cell-op or a primitive-op. + + Returns: + A dict represents attributes of input 'obj'. + """ + attributes = {} + if obj is None: + return attributes + for k, v in obj.__dict__.items(): + if k.startswith("_"): + continue + attributes[k] = v + attributes["cls"] = obj.__class__ + return attributes + + +class TreeNode(Node): + """Tree type Node who holds a handler of SymbolTree.""" + + def __init__(self, tree, ast_node: ast.AST, targets: [ScopedValue], func: Union[ScopedValue, str], + args: [ScopedValue], kwargs: {str: ScopedValue}, name: str, instance): + """ + Constructor of Node. Rewrite recommend to invoking class method of Node to instantiate an instance of Node such + as `create_call_cell`, `create_call_method`, `create_python_node`, `create_input_node` and `create_output_node`, + etc. rather than invoking constructor of Node directly. + + Args: + tree: An instance of SymbolTree represents a handler of sub-symbol-tree. + ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast. + targets ([ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class. + func (Optional[ScopedValue]): An instance of ScopedValue. See detail in docstring of Node class. + args ([ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class. + kwargs ({str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class. + name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree. + Name of node also used as field name in network class. + instance: Object in network corresponding to this node. + """ + if isinstance(func, str): + func = ScopedValue.create_naming_value(func) + super().__init__(NodeType.Tree, ast_node, targets, func, args, kwargs, name, instance) + self.symbol_tree = tree + + @classmethod + def create_tree_node(cls, tree, ast_node: ast.AST, targets: [ScopedValue], func: Union[ScopedValue, str], + args: [ScopedValue], kwargs: {str: ScopedValue}, name: str = "", instance=None): + """ + Class method of TreeNode. Instantiate an instance of node whose type is Tree. A Tree node represents an invoking + to sub-network. + + Args: + tree: An instance of SymbolTree represents a handler of sub-symbol-tree. + ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast. + targets ([ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class. + func (Optional[ScopedValue]): An instance of ScopedValue. See detail in docstring of Node class. + args ([ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class. + kwargs ({str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class. + name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree. + Name of node also used as field name in network class. + instance: Object in network corresponding to this node. + """ + return cls(tree, ast_node, targets, func, args, kwargs, name, instance) diff --git a/mindspore/python/mindspore/rewrite/parser.py b/mindspore/python/mindspore/rewrite/parser.py new file mode 100644 index 00000000000..a9941aea795 --- /dev/null +++ b/mindspore/python/mindspore/rewrite/parser.py @@ -0,0 +1,45 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Base class of parser.""" +import abc +import ast + +from .symbol_tree import SymbolTree + + +class Parser(abc.ABC): + """ + DFS into a ast_node until add node into SymbolTree + """ + + def target(self) -> type: + """ + Get type of ast which could be accepted by current parser. + + Returns: + A type of ast. + """ + return type(None) + + @abc.abstractmethod + def process(self, stree: SymbolTree, node: ast.AST): + """ + Parse input ast node and add parse result into SymbolTree. + + Args: + stree (SymbolTree): current symbol_tree + node (ast.AST): node who is tried to be parsed + """ + raise NotImplementedError diff --git a/mindspore/python/mindspore/rewrite/parser_register.py b/mindspore/python/mindspore/rewrite/parser_register.py new file mode 100644 index 00000000000..a1cfa28c178 --- /dev/null +++ b/mindspore/python/mindspore/rewrite/parser_register.py @@ -0,0 +1,86 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Parser register.""" +from typing import Optional + +from .parser import Parser + + +class ParserRegister: + """Parser register.""" + + def __init__(self): + self._parsers: dict = {} + + @classmethod + def instance(cls) -> 'ParserRegister': + """ + Get singleton of ParserRegister. + + Returns: + An instance of ParserRegister. + """ + if not hasattr(ParserRegister, "_instance"): + ParserRegister._instance = ParserRegister() + return ParserRegister._instance + + @staticmethod + def reg_parser(parser: Parser): + """ + Register a 'parser' to current ParserRegister. + + Args: + parser (Parser): An instance of Parser to be registered. + """ + if isinstance(parser, Parser): + ParserRegister.instance()._parsers[parser.target()] = parser + + def get_parser(self, ast_type: type) -> Optional[Parser]: + """ + Get parser from current ParserRegister by type of ast. + + Args: + ast_type (type): An type of ast which want to be parsed. + + Returns: + An instance of Parser if there existing suitable parser in current ParserRegister else None. + """ + return self._parsers.get(ast_type) + + def get_parsers(self) -> [Parser]: + """ + Get all parsers registered in current ParserRegister. + + Returns: + An list of instances of Parser. + """ + return self._parsers + + +class ParserRegistry: + """Parser registry.""" + + def __init__(self, parser: Parser): + ParserRegister.instance().reg_parser(parser) + + +def reg_parser(parser: Parser): + """ + A global method for registering parser into ParserRegister singleton. + + Args: + parser (Parser): An instance of Parser to be registered. + """ + return ParserRegistry(parser) diff --git a/mindspore/python/mindspore/rewrite/parsers/__init__.py b/mindspore/python/mindspore/rewrite/parsers/__init__.py new file mode 100644 index 00000000000..018c134e030 --- /dev/null +++ b/mindspore/python/mindspore/rewrite/parsers/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Parsers for resolve ast to SymbolTree +""" diff --git a/mindspore/python/mindspore/rewrite/parsers/arguments_parser.py b/mindspore/python/mindspore/rewrite/parsers/arguments_parser.py new file mode 100644 index 00000000000..e34198a4fe9 --- /dev/null +++ b/mindspore/python/mindspore/rewrite/parsers/arguments_parser.py @@ -0,0 +1,61 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Parse ast.arguments to input-node of SymbolTree.""" +import ast + +from ..parser import Parser +from ..parser_register import reg_parser +from ..symbol_tree import SymbolTree + + +class ArgumentsParser(Parser): + """Parse ast.arguments to input-node of SymbolTree.""" + + def target(self): + """Parse target type""" + return ast.arguments + + def process(self, stree: SymbolTree, node: ast.arguments): + """ + Parse ast.arguments and create input-node to stree. + + Args: + stree (SymbolTree): symbol tree under parsing. + node (ast.arguments): argument node in construct. + + Raises: + RuntimeError: Types of node.args elements are not ast.arg. + """ + if hasattr(node, "posonlyargs"): + stree.try_append_python_node(node, node.posonlyargs) + + for arg in node.args: + if not isinstance(arg, ast.arg): + raise RuntimeError("Unsupported ast type in arguments arg: ", arg) + stree.append_input_node(arg.arg) + + if hasattr(node, "vararg"): + stree.try_append_python_node(node, node.vararg) + if hasattr(node, "kwonlyargs"): + stree.try_append_python_node(node, node.kwonlyargs) + if hasattr(node, "kw_defaults"): + stree.try_append_python_node(node, node.kw_defaults) + if hasattr(node, "kwarg"): + stree.try_append_python_node(node, node.kwarg) + if hasattr(node, "defaults"): + stree.try_append_python_node(node, node.defaults) + + +g_arguments_parser = reg_parser(ArgumentsParser()) diff --git a/mindspore/python/mindspore/rewrite/parsers/assign_parser.py b/mindspore/python/mindspore/rewrite/parsers/assign_parser.py new file mode 100644 index 00000000000..a069f85a18f --- /dev/null +++ b/mindspore/python/mindspore/rewrite/parsers/assign_parser.py @@ -0,0 +1,252 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Parse ast.Assign in construct function to node of SymbolTree.""" +import ast +import astunparse + +from mindspore import log as logger +from ..symbol_tree import SymbolTree +from ..node import Node, TreeNode +from ..parser import Parser +from ..parser_register import reg_parser +from ..api.scoped_value import ScopedValue +from ..symbol_tree_builder import SymbolTreeBuilder + + +class AssignParser(Parser): + """Parse ast.Assign in construct function to node of SymbolTree.""" + + def target(self): + """Parse target type.""" + return ast.Assign + + @staticmethod + def _create_scopedvalue_from_tuple_ast(node: ast.Tuple) -> ScopedValue: + """ + Create ScopedValue from a tuple ast node. + + Args: + node (ast.Tuple): A tuple node. + + Returns: + An instance of ScopedValue. + + Raises: + RuntimeError: Only support ast.Constant as elts of ast.Tuple. + """ + tuple_elts = node.elts + tuple_values = [] + for tuple_elt in tuple_elts: + if not isinstance(tuple_elt, ast.Constant): + raise RuntimeError("Only support ast.Constant as elts of ast.Tuple.") + tuple_values.append(tuple_elt.value) + return ScopedValue.create_variable_value(tuple(tuple_values)) + + @staticmethod + def _create_scopedvalue(node: ast.expr) -> ScopedValue: + """ + Create ScopedValue from an ast node. + + Args: + node (ast.expr): An ast node. + + Returns: + An instance of ScopedValue. + + Raises: + RuntimeError: Value of target of ast.Assign should be an ast.Name when target is an ast.Attribute. + RuntimeError: Type of input node is unsupported. + """ + if isinstance(node, ast.Name): + return ScopedValue.create_naming_value(node.id) + if isinstance(node, ast.Attribute): + scope = node.value + if not isinstance(scope, ast.Name): + raise RuntimeError("value of target of ast.Assign should be a ast.Name when target is a ast.Attribute.") + return ScopedValue.create_naming_value(node.attr, scope.id) + if isinstance(node, ast.Tuple): + return AssignParser._create_scopedvalue_from_tuple_ast(node) + if isinstance(node, ast.Constant): + return ScopedValue.create_variable_value(node.value) + raise RuntimeError("Unsupported ast type to argument:", node) + + @staticmethod + def _get_func_name(ast_node: ast.Call) -> str: + """ + Get the func name from ast.Call. + + Args: + ast_node (ast.Call): Input ast.Call node. + + Returns: + Func name. + + Raises: + RuntimeError: Func of input ast node is not ast.Name or ast.Attribute. + """ + func = ast_node.func + if isinstance(func, ast.Name): + return func.id + if isinstance(func, ast.Attribute): + return func.attr + raise RuntimeError("FuncValue is should be Name or a Attribute:", astunparse.unparse(func)) + + @staticmethod + def _get_func_scope(ast_node: ast.Call) -> str: + """ + Get the func scope from ast.Call. + + Args: + ast_node (ast.Call): Input ast.Call node. + + Returns: + Func scope. + + Raises: + RuntimeError: FuncValue is not an ast.Name when func is an ast.Attribute. + RuntimeError: Func of input ast node is not ast.Name or ast.Attribute. + """ + func = ast_node.func + if isinstance(func, ast.Name): + return "" + if isinstance(func, ast.Attribute): + value = func.value + if not isinstance(value, ast.Name): + raise RuntimeError("FuncValue is should be Name:", ast.dump(func)) + return value.id + raise RuntimeError("FuncValue is should be Name or a Attribute:", ast.dump(func)) + + @staticmethod + def _get_symbol_object(symbol_name, origin_net): + """ + Get the func scope from ast.Call. + + Args: + symbol_name (str): Func name. + origin_net ([nn.Cell]): Network instance. + + Returns: + Symbol Object. + """ + var_dict = origin_net.__dict__ + for key, value in var_dict["_cells"].items(): + if key == symbol_name: + return value + + for key, value in var_dict["_primitives"].items(): + if key == symbol_name: + return value + return None + + @staticmethod + def _create_kwargs(keywords: [ast.keyword]) -> {str, ScopedValue}: + """ + Transfer ast.Call keywords to a dict of ScopedValue when creating a symbol tree node. + + Args: + keywords ([ast.keyword]): Keywords of ast.Call node. + + Returns: + A dict of ScopedValue. + """ + results = {} + for keyword in keywords: + results[keyword.arg] = AssignParser._create_scopedvalue(keyword.value) + return results + + @staticmethod + def _convert_ast_call_to_node(ast_node: ast.Call, father_ast_node: ast.Assign, stree: SymbolTree) -> Node: + """ + Convert ast.Call to a symbol tree node. + + Args: + ast_node ([ast.Call]): An ast.Call of assign node in construct. + father_ast_node ([ast.Assign]): Assign node in construct. + stree ([SymbolTree]): Symbol Tree under parsing. + + Returns: + An instance of Node in Symbol Tree. + + Raises: + RuntimeError: kwargs in construct function assign is unsupported. + """ + target = AssignParser._create_scopedvalue(father_ast_node.targets[0]) + func_name = AssignParser._get_func_name(ast_node) + if func_name is None or func_name == "": + raise RuntimeError("function name not exist") + func_scope = AssignParser._get_func_scope(ast_node) + func = ScopedValue.create_naming_value(func_name, func_scope) + call_args = [AssignParser._create_scopedvalue(arg) for arg in ast_node.args] + call_kwargs = AssignParser._create_kwargs(ast_node.keywords) + if ast_node.keywords: + raise RuntimeError("kwargs in construct function assign is unsupported.") + + obj = AssignParser._get_symbol_object(func_name, stree.get_origin_network()) + # need check if node is a callmethod, like: x = len(x) + # need check if node is a callprimitive, like: x = x * 5 + is_sub_tree = False + if is_sub_tree: + stb = SymbolTreeBuilder(obj) + new_stree = stb.build() + return TreeNode(new_stree, father_ast_node, [target], func, call_args, call_kwargs, func_name, + new_stree.get_origin_network()) + return Node.create_call_cell(obj, father_ast_node, [target], func, call_args, call_kwargs, func_name) + + def process(self, stree: SymbolTree, node: ast.Assign): + """ + Parse ast.Assign and create a node in symbol tree. + Will create node when value of ast.Assign is in [ast.Call, ast.Name, ast.Constant, ast.Attribute]. + Will create python node when value of ast.Assign is in + [ast.BinOp, ast.BoolOp, ast.Subscript, ast.List, ast.Tuple, ast.Dict]. + Other value types are not supported. + + Args: + stree ([SymbolTree]): Symbol Tree under parsing. + node ([ast.Assign]): An ast.Assign node. + + Raises: + RuntimeError: Only support one target in assign now. + RuntimeError: Unsupported node type in construct function. + """ + targets = node.targets + if len(targets) != 1: + raise RuntimeError("Only support one target in assign now") + value = node.value + if isinstance(value, ast.Call): + node_ = AssignParser._convert_ast_call_to_node(value, node, stree) + stree.append_origin_field(node_) + elif isinstance(value, (ast.BinOp, ast.BoolOp, ast.Subscript)): + logger.warning(f"ops-call({astunparse.unparse(node)}) in assign will be supported in near feature, " + f"ignored as a python node now") + stree.try_append_python_node(node, node) + elif isinstance(value, (ast.Name, ast.Constant, ast.Attribute, ast.Num, ast.NameConstant, ast.Bytes, ast.Str)): + if isinstance(value, ast.Name): + node_name = "name_assign" + elif isinstance(value, ast.Constant): + node_name = "constant_assign" + else: + node_name = "attribute_assign" + target = AssignParser._create_scopedvalue(node.targets[0]) + call_args = [AssignParser._create_scopedvalue(value)] + node_ = Node.create_call_pass_through_method(node, [target], call_args, {}, node_name) + stree.append_origin_field(node_) + elif isinstance(value, (ast.List, ast.Tuple, ast.Dict)): + # add these as callmethod node if necessary + stree.try_append_python_node(node, node) + else: + raise RuntimeError(f"Unsupported statement({astunparse.unparse(node)}) in construct function!") + + +g_assign_parser = reg_parser(AssignParser()) diff --git a/mindspore/python/mindspore/rewrite/parsers/class_def_parser.py b/mindspore/python/mindspore/rewrite/parsers/class_def_parser.py new file mode 100644 index 00000000000..b86694b1481 --- /dev/null +++ b/mindspore/python/mindspore/rewrite/parsers/class_def_parser.py @@ -0,0 +1,170 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Parse ast.ClassDef which is subclass of Cell to SymbolTree.""" +import ast + +from mindspore import log as logger +from ..symbol_tree import SymbolTree +from ..parser import Parser +from ..parser_register import ParserRegister, reg_parser +from ..api.scoped_value import ScopedValue +from ..ast_modifier import AstModifier + + +class ClassDefParser(Parser): + """Parse ast.ClassDef which is subclass of Cell to SymbolTree.""" + + def target(self): + """Parse target type""" + return ast.ClassDef + + @staticmethod + def _process_init_func_ast(init_ast: ast.FunctionDef, ori_cls_name: str, opt_cls_name: str): + """Process init func""" + super_index = ClassDefParser._modify_super_expr_of_init_func(init_ast, ori_cls_name, opt_cls_name) + ClassDefParser._modify_arguments_of_init_func(init_ast) + ClassDefParser._replace_ori_field_of_init_func(init_ast.body, super_index) + ClassDefParser._insert_handler_to_init_func(init_ast, super_index) + + @staticmethod + def _modify_super_expr_of_init_func(ast_init_fn: ast.FunctionDef, ori_cls_name: str, opt_cls_name: str) -> int: + """Modify network name in super(XXnet).__init__()""" + if not ast_init_fn.body: + return -1 + super_index = -1 + super_call_args = None + while True: + super_index += 1 + expr = ast_init_fn.body[super_index] + if not isinstance(expr, ast.Expr): + continue + expr_value = expr.value + if not isinstance(expr_value, ast.Call): + continue + expr_value_func = expr_value.func + if not isinstance(expr_value_func, ast.Attribute): + continue + expr_value_func_value = expr_value_func.value + if expr_value_func.attr != "__init__" or not isinstance(expr_value_func_value, ast.Call): + continue + expr_value_func_value_func = expr_value_func_value.func + if not isinstance(expr_value_func_value_func, ast.Name) or expr_value_func_value_func.id != "super": + continue + super_call_args = expr_value_func_value.args + break + if super_call_args is None or not isinstance(super_call_args, list) or len(super_call_args) != 2: + return super_index + super_call_arg = super_call_args[0] + if super_call_arg.id != ori_cls_name: + raise RuntimeError("super_call_arg.id should equal to ori_cls_name") + super_call_arg.id = opt_cls_name + return super_index + + @staticmethod + def _modify_arguments_of_init_func(ast_init_fn: ast.FunctionDef): + """Replace init function input parameters to self and global_vars.""" + arg_self = ast.arg(arg="self", annotation="") + arg_global_vars = ast.arg(arg="global_vars", annotation="") + ast_init_fn.args = ast.arguments(args=[arg_self, arg_global_vars], posonlyargs=[], kwonlyargs=[], + kw_defaults=[], defaults=[], vararg=None, kwarg=None) + ast.fix_missing_locations(ast_init_fn) + + @staticmethod + def _replace_ori_field_of_init_func(bodies: [], super_index: int): + """ + Replace original field in init func to self.XX = getattr(self._handler, "XX"). + Only keep following two kinds of ast nodes in bodies right now: + 1. Ast.If and test is self.XX. + 2. Ast.Assign and target is self.XX. + + Args: + bodies ([]): bodied of init ast.FunctionDef. + super_index (int): index of super().__init__() in bodies. + + Raises: + RuntimeError: Not support multi-targets in assign. + RuntimeError: Only support target.value in [ast.Name] in assign node. + """ + body_index_to_be_deleted = [] + for body_index, body in enumerate(bodies): + if body_index == super_index: + continue # ignoring super.__init__() + if isinstance(body, ast.If) and isinstance(body.test, ast.Attribute) \ + and isinstance(body.test.value, ast.Name) and body.test.value.id == 'self': + ClassDefParser._replace_ori_field_of_init_func(body.body, -1) + ClassDefParser._replace_ori_field_of_init_func(body.orelse, -1) + continue + if not isinstance(body, ast.Assign): # if not assign node, delete + body_index_to_be_deleted.append(body_index) + continue + if len(body.targets) != 1: + raise RuntimeError("Not support multi-targets in assign now!") + target = body.targets[0] + if not isinstance(target, ast.Attribute): # only keep class member + body_index_to_be_deleted.append(body_index) + continue + if not isinstance(target.value, ast.Name): + raise RuntimeError("Only support target.value in ast.Name now!") + target_value: ast.Name = target.value + if target_value.id != "self": + body_index_to_be_deleted.append(body_index) + continue + field_name = target.attr + body.value = ast.Call(ast.Name('getattr', ast.Load()), + [ast.Attribute(ast.Name('self', ast.Load()), '_handler', ast.Load()), + ast.Constant(value=field_name, kind=None)], []) + for counter, index in enumerate(body_index_to_be_deleted): + bodies.pop(index - counter) + + @staticmethod + def _insert_handler_to_init_func(ast_init_fn: ast.FunctionDef, super_index): + """Insert 'self._handler = global_vars.get('handler')' to init ast.FunctionDef.body""" + if super_index == -1: + super_index = 0 + AstModifier.insert_assign_to_function(ast_init_fn, [ScopedValue.create_naming_value("_handler", "self")], + ScopedValue.create_naming_value("get", "global_vars"), + [ScopedValue.create_variable_value("handler")], None, + ast_init_fn.body[super_index], False) + + def process(self, stree: SymbolTree, node: ast.ClassDef): + """ + Parse init and construct in ast.ClassDef. + + Args: + stree ([SymbolTree]): Symbol Tree under parsing. + node ([ast.ClassDef]): An ast.ClassDef node. + """ + # change class name + node.name = stree.get_opt_cls_name() + + stree.set_class_ast(node) + + for body in node.body: + if isinstance(body, ast.FunctionDef): + if body.name == "__init__": + ClassDefParser._process_init_func_ast(body, stree.get_ori_cls_name(), stree.get_opt_cls_name()) + stree.set_init_func_ast(body) + elif body.name == "construct": + parser: Parser = ParserRegister.instance().get_parser(ast.FunctionDef) + parser.process(stree, body) + else: + logger.warning( + "Ignoring ast.FunctionDef in ast.ClassDef except __init__ and construct function: %s", + body.name) + else: + logger.warning("Ignoring unsupported node(%s) in ast.ClassDef.", type(body).__name__) + + +g_classdef_parser = reg_parser(ClassDefParser()) diff --git a/mindspore/python/mindspore/rewrite/parsers/function_def_parser.py b/mindspore/python/mindspore/rewrite/parsers/function_def_parser.py new file mode 100644 index 00000000000..fbde74543dc --- /dev/null +++ b/mindspore/python/mindspore/rewrite/parsers/function_def_parser.py @@ -0,0 +1,53 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Parse bodies of ast.FunctionDef which is construct function to nodes of SymbolTree.""" +import ast + +from ..parser_register import ParserRegister, reg_parser +from ..parser import Parser +from ..symbol_tree import SymbolTree + + +class FunctionDefParser(Parser): + """Parse bodies of ast.FunctionDef which is construct function to nodes of SymbolTree.""" + + def target(self): + """Parse target type""" + return ast.FunctionDef + + def process(self, stree: SymbolTree, node: ast.FunctionDef): + """Parse bodies of ast.FunctionDef which is construct function to nodes of SymbolTree.""" + stree.set_ast_root(node) + # parse args as inputs of stree + arguments: ast.arguments = node.args + parser: Parser = ParserRegister.instance().get_parser(ast.arguments) + parser.process(stree, arguments) + + # parse body as node of stree + for body in node.body: + # avoid add dead code, so we need to break if return is added. + parser: Parser = ParserRegister.instance().get_parser(type(body)) + if parser is None: + stree.append_python_node(node, body) + else: + parser.process(stree, body) + + if hasattr(node, "decorator_list"): + stree.try_append_python_node(node, node.decorator_list) + if hasattr(node, "returns"): + stree.try_append_python_node(node, node.returns) + + +g_functiondef_parser = reg_parser(FunctionDefParser()) diff --git a/mindspore/python/mindspore/rewrite/parsers/module_parser.py b/mindspore/python/mindspore/rewrite/parsers/module_parser.py new file mode 100644 index 00000000000..3026f58fa55 --- /dev/null +++ b/mindspore/python/mindspore/rewrite/parsers/module_parser.py @@ -0,0 +1,124 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Parse ast.Module to SymbolTrees.""" +from typing import Any +import os +import ast +import copy +import inspect +import astunparse + +from mindspore import log as logger +from ..symbol_tree import SymbolTree +from ..parser import Parser +from ..parser_register import ParserRegister, reg_parser + + +class ClassFinder(ast.NodeVisitor): + """Find all ast.ClassDef in input ast node.""" + + def __init__(self): + """Keep all found ast.ClassDef in self._classes""" + self._classes: [ast.ClassDef] = [] + + def visit_ClassDef(self, node: ast.ClassDef) -> Any: + """Iterate over all nodes and save ast.ClassDef nodes.""" + self._classes.append(node) + + def find_all_classes(self, node: ast.AST) -> [ast.ClassDef]: + """Interface of ClassFinder.""" + self.visit(node) + return self._classes + + +class ModuleParser(Parser): + """Parse ast.Module to SymbolTrees.""" + + def target(self): + """Parse target type""" + return ast.Module + + @staticmethod + def _find_class(ast_node: ast.Module) -> ast.ClassDef: + """Find all ast.ClassDef in ast.Module, only support one ast.ClassDef in ast.Module now.""" + visitor = ClassFinder() + classes = visitor.find_all_classes(ast_node) + if not classes: + raise RuntimeError("No class in module") + if len(classes) > 1: + raise RuntimeError("Multi-class in module is not supported now") + return classes[0] + + @staticmethod + def get_import_node(ast_root): + """Iterate over ast_root and return all ast.Import nodes or ast.ImportFrom nodes in ast_root.""" + import_nodes = [] + + class GetImportNode(ast.NodeVisitor): + """Find all import nodes from input ast node.""" + def visit_Import(self, node: ast.Import) -> Any: + """Iterate over all nodes and save ast.Import nodes.""" + import_nodes.append(copy.deepcopy(node)) + return node + + def visit_ImportFrom(self, node: ast.ImportFrom) -> Any: + """Iterate over all nodes and save ast.ImportFrom nodes.""" + import_nodes.append(copy.deepcopy(node)) + return node + + def get_node(self, input_ast): + """Interface of GetImportNode.""" + self.generic_visit(input_ast) + return True + + get_node_handler = GetImportNode() + get_node_handler.get_node(ast_root) + return import_nodes + + @staticmethod + def _add_import_to_module(module: ast.Module, origin_net): + """Insert two groups of import nodes to ast.Module, common ones and those from class definition file.""" + module.body.insert(0, ast.Import([ast.alias(name='mindspore', asname=None)])) + module.body.insert(1, ast.ImportFrom(module='mindspore', names=[ast.alias(name='nn', asname=None)], level=0)) + module.body.insert(2, ast.ImportFrom(module='mindspore.nn', names=[ast.alias(name='Cell', asname=None)], + level=0)) + origin_net_source_code_file = inspect.getfile(type(origin_net)) + if not os.path.exists(origin_net_source_code_file): + raise RuntimeError("File ", origin_net_source_code_file, " not exist") + try: + with open(origin_net_source_code_file, "r") as f: + source_code = f.read() + import_nodes = ModuleParser.get_import_node(ast.parse(source_code)) + except RuntimeError: + raise RuntimeError("get import nodes error") + if import_nodes: + for import_index, import_node in enumerate(import_nodes): + module.body.insert(import_index + 3, import_node) + ast.fix_missing_locations(module) + + def process(self, stree: SymbolTree, node: ast.Module): + """Process ast.ClassDef nodes in ast.Module.""" + ModuleParser._add_import_to_module(node, stree.get_origin_network()) + class_ast = ModuleParser._find_class(node) + stree.set_class_ast(class_ast) + for body in node.body: + if isinstance(body, ast.ClassDef): + parser: Parser = ParserRegister.instance().get_parser(ast.ClassDef) + parser.process(stree, body) + else: + logger.info(f"Ignoring unsupported node({astunparse.unparse(body)}) in ast.Module.") + + +g_module_parser = reg_parser(ModuleParser()) diff --git a/mindspore/python/mindspore/rewrite/parsers/return_parser.py b/mindspore/python/mindspore/rewrite/parsers/return_parser.py new file mode 100644 index 00000000000..d70c707f321 --- /dev/null +++ b/mindspore/python/mindspore/rewrite/parsers/return_parser.py @@ -0,0 +1,40 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Parse ast.Return output-node of SymbolTree.""" +import ast + +from ..symbol_tree import SymbolTree +from ..node import Node +from ..parser import Parser +from ..parser_register import reg_parser + + +class ReturnParser(Parser): + """Parse ast.Return output-node of SymbolTree.""" + + def target(self): + """Parse target type""" + return ast.Return + + def process(self, stree: SymbolTree, node: ast.Return): + """Parse ast.Return to output-node of SymbolTree.""" + return_value = node.value + if not isinstance(return_value, ast.Name): + raise RuntimeError("Only ast.Name as return value") + node_return = Node.create_output_node(node, [return_value.id]) + stree.append_origin_field(node_return) + + +g_return_parser = reg_parser(ReturnParser()) diff --git a/mindspore/python/mindspore/rewrite/symbol_tree.py b/mindspore/python/mindspore/rewrite/symbol_tree.py new file mode 100644 index 00000000000..96800b2300e --- /dev/null +++ b/mindspore/python/mindspore/rewrite/symbol_tree.py @@ -0,0 +1,899 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""SymbolTree class define of Rewrite according to forward function of a network.""" +from typing import Optional, Union, Tuple +import os +import sys +import ast +import tempfile +import astunparse + +from mindspore.nn import Cell +from mindspore import log as logger +from .node import Node, TreeNode +from .api.node_type import NodeType +from .ast_modifier import AstModifier +from .api.scoped_value import ScopedValue, ValueType +from .symbol_tree_dumper import SymbolTreeDumper +from .topological_manager import TopoManager +from .namer import TargetNamer, NodeNamer + + +class Position: + """Position indicates a source code position in one network.""" + + def __init__(self, symbol_tree, node, before_node: bool): + """ + Constructor of Position. + Recommend to use class method of position rather than constructor of Position. + + Args: + symbol_tree (SymbolTree): A handler of SymbolTree indicated position in which SymbolTree. + node (Node): A handler of Node indicated position is around which Node. + before_node (bool): A bool indicated position is before or after the 'node'. + """ + self.symbol_tree = symbol_tree + self.node = node + self.before_node = before_node + + @classmethod + def create(cls, symbol_tree, node, before_node): + """ + Class method of Position. Return None when symbol_tree or node is None. + + Args: + symbol_tree: A handler of SymbolTree indicated position in which SymbolTree. + node: A handler of Node indicated position is around which Node. + before_node (bool): A bool indicated position is before or after the 'node'. + + Returns: + A Position. + """ + if symbol_tree is None or node is None: + return None + return Position(symbol_tree, node, before_node) + + +class SymbolTree: + """A symbol-tree usually corresponding to forward method of a network.""" + + def __init__(self, origin_network: Cell, module_ast: ast.Module): + """ + Constructor of SymbolTree. Rewrite recommend using SymbolTreeBuilder to instantiate an instance of SymbolTree + rather than invoking constructor of SymbolTree directly. + + Args: + origin_network (Cell): A handler to original network instance. + module_ast (ast.Module): An instance of ast.AST represents ast node of original network. + """ + origin_network_key = "handler" + # init unique-namers + self._target_namer = TargetNamer() + self._node_name_namer = NodeNamer() + # name or node would use as name of field, so name of origin network handler field should be added into \ + # _node_name_namer. + self._node_name_namer.add_name(origin_network_key) + self._topo_mgr = TopoManager() + + self._global_vars: {str, object} = {origin_network_key: origin_network} + self._nodes: {str, Node} = {} + # parameters of forward method + self._inputs: [Node] = [] + self._ori_cls_name = type(origin_network).__name__ + self._opt_cls_name = self._ori_cls_name + "Opt" + self._origin_network = origin_network + self._module_ast: ast.Module = module_ast + self._class_ast: Optional[ast.ClassDef] = None + self._root_ast: Optional[ast.FunctionDef] = None + self._init_func_ast: Optional[ast.FunctionDef] = None + + # head node is always point to the first node(in source code order) of SymbolTree + self._head = None + # tail node is always point to the last node(in source code order) of SymbolTree + self._tail = None + self._return: Optional[Node] = None + + def get_ori_cls_name(self) -> str: + """ + Get class name of original network. + + Returns: + A str represents class name of original network. + """ + return self._ori_cls_name + + def get_opt_cls_name(self) -> str: + """ + Get class name of rewritten network. + + Returns: + A str represents class name of rewritten network. + """ + return self._opt_cls_name + + def get_module_ast(self): + """ + Getter of `_module_ast`. + + Returns: + An instance of ast.AST represents ast node of corresponding module. + """ + return self._module_ast + + def get_ast_root(self): + """ + Getter of `_root_ast`. + + Returns: + An instance of ast.AST represents ast node of corresponding forward method. + """ + return self._root_ast + + def set_class_ast(self, ast_node: ast.ClassDef): + """ + Setter of `_class_ast`. + + Args: + ast_node (ast.ClassDef): An instance of ast.ClassDef represents ast node of corresponding network class. + """ + self._class_ast = ast_node + + def set_init_func_ast(self, ast_node: ast.FunctionDef): + """ + Setter of `_init_func_ast`. + + Args: + ast_node (ast.FunctionDef): An instance of ast.FunctionDef represents ast node of init method of + corresponding network class. + """ + self._init_func_ast = ast_node + + def set_ast_root(self, ast_node: ast.FunctionDef): + """ + Setter of `_root_ast`. + + Args: + ast_node (ast.FunctionDef): An instance of ast.FunctionDef represents ast node of forward method of + corresponding network class. + """ + self._root_ast = ast_node + + def get_inputs(self): + """ + Getter of `_inputs` which represents parameters of current forward method. + + Returns: + A list of instance of Node whose node_type is NodeType.Input as input nodes. + """ + return self._inputs + + def get_head_node(self): + """ + Getter of `_head` which represents the beginning node while iterating SymbolTree nodes. + + Returns: + An instance of node. + """ + return self._head + + def get_return_node(self): + """ + Getter of `_return` which represents return statement of forward method of network. + + Returns: + An instance of node. + """ + return self._return + + def get_origin_network(self): + """ + Getter of `_origin_network`. + + Returns: + An instance of Cell which represents original network. + """ + return self._origin_network + + def nodes(self, unfold_subtree=True): + """ + Getter of nodes if current SymbolTree. + + Args: + unfold_subtree (bool): Need to iterate into sub-symbol-tree recursively. + + Returns: + A list of instance of Nodes. + """ + if unfold_subtree: + nodes = [] + for _, v in self._nodes.items(): + if isinstance(v, TreeNode): + nodes.extend(self.nodes(v.symbol_tree)) + else: + nodes.append(v) + return nodes + return self._nodes.values() + + def get_node(self, node_name: str) -> Optional[Node]: + """ + Get node of current symbol_tree by `node_name`. + + Args: + node_name (str): A str represents name of node as key of query. + + Returns: + An instance of Node if found else None. + """ + + return self._nodes.get(node_name) + + def _get_real_node(self, node_or_name: Union[Node, str]) -> Optional[Node]: + if isinstance(node_or_name, Node): + return self.get_node(node_or_name.get_name()) + if isinstance(node_or_name, str): + return self.get_node(node_or_name) + return None + + def get_node_inputs(self, node_or_name: Union[Node, str]) -> [Node]: + """ + Getter of inputs in topological relation of current 'node_or_name'. + + Args: + node_or_name (Union[Node, str]): An instance of node or a str represents name of node. + + Returns: + A list of instances of Node as input nodes if 'node_or_name' belong to current SymbolTree. An empty list if + 'node_or_name' not belong to current SymbolTree. + """ + + real_node: Optional[Node] = self._get_real_node(node_or_name) + if real_node is None: + logger.info("Node(%s) is not belong to current SymbolTree", node_or_name) + return [] + return node_or_name.get_inputs() + + def get_node_users(self, node_or_name: Union[Node, str]) -> [Tuple[Node, int]]: + """ + Getter of outputs in topological relation of current 'node_or_name'. + + Args: + node_or_name (Union[Node, str]): An instance of node or a str represents name of node. + + Returns: + A list of instances of Node as output nodes if 'node_or_name' belong to current SymbolTree. An empty list if + 'node_or_name' not belong to current SymbolTree. + """ + + real_node: Optional[Node] = self._get_real_node(node_or_name) + if real_node is None: + logger.info("Node(%s) is not belong to current SymbolTree", node_or_name) + return [] + return self._topo_mgr.get_node_users(node_or_name) + + def before(self, node_or_name: Union[Node, str]) -> Position: + """ + Get insert position before 'node_or_name' in source code list. + Consider using symbol_tree, node and before/after as position for sub-tree feature. + + Note: + Topological order is not determined here which is determined by arguments of node and updated by + TopologicalManager automatically. + + Args: + node_or_name (Union[Node, str]): An instance of node or a str represents name of node. + + Returns: + A Position represents an insert point. + + Raises: + AssertError: If 'node_or_name' is not a Node or a str + RuntimeError: If 'node_or_name' is not belong to this SymbolTree or any sub-SymbolTree of current + SymbolTree. + """ + + node = self._get_real_node(node_or_name) + if node is None: + raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name) + return Position.create(node.get_belong_symbol_tree(), node, True) + + def after(self, node_or_name: Union[Node, str]) -> Position: + """ + Get insert position after 'node_or_name' in source code list. + Consider using symbol_tree, node and before/after as position for sub-tree feature. + + Note: + Topological order is not determined here which is determined by arguments of node and updated by + TopologicalManager automatically. + + Args: + node_or_name (Union[Node, str]): An instance of node or a str represents name of node. + + Returns: + A Position represents an insert point. + + Raises: + AssertError: If 'node_or_name' is not a Node or a str + RuntimeError: If 'node_or_name' is not belong to this SymbolTree or any sub-SymbolTree of current + SymbolTree. + """ + + node = self._get_real_node(node_or_name) + if node is None: + raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name) + return Position.create(node.get_belong_symbol_tree(), node, False) + + def insert_node(self, position: Position, node: Node, insert_to_ast: bool = True) -> Node: + """ + Insert a node into SymbolTree. + Note: + Name of node will be unique while inserting node into SymbolTree. + + ValueType.CustomObjValue type arguments will be converted to ValueType.NamingValue and custom object will + be saved in global_vars dict while inserting node into SymbolTree. + + Targets of node will be unique while inserting node into SymbolTree. + + A field instantiation statement will be added into "init" function of network class using node name as field + name when `insert_to_ast` is True while inserting node into SymbolTree. + + An assign statement represents invoking to this node will be added into forward function of network class + corresponding to field-instantiation-statement when `insert_to_ast` is True while inserting node into + SymbolTree. + + Topological relation is updated and inputs of corresponding node is updated. + + Args: + position (Position): A Position indicates an insert position point. + node (Node): An instance of node to be inserted in. + insert_to_ast (bool): A bool indicates whether to update corresponding ast node at same time, default is + True. + + Returns: + An instance of node which has been inserted into SymbolTree. + + Raises: + RuntimeError: If 'position' is not in current SymbolTree. + RuntimeError: If corresponding ast node is not an ast.Assign when 'insert_to_ast' is True. + """ + + # if position in current SymbolTree + if position is not None and position.symbol_tree is not self: + raise RuntimeError("Position is not in current SymbolTree:", position) + # unique targets, name while insert node into symbol_tree + node_name = self._node_name_namer.get_name(node) + node.set_name(node_name) + self._handle_custom_obj_in_normalized_args(node) + # _unique_targets must called after _update_args_for_unique and _update_kwargs_for_unique + self._unique_targets(node) + self._insert_node(position, node) + # update init-function-ast and construct-function-ast + if insert_to_ast: + node.set_func(ScopedValue.create_naming_value(node_name, "self")) + node_ast = node.get_ast() + if not isinstance(node_ast, ast.Assign): + raise RuntimeError("Only support insert cell op now") + AstModifier.insert_assign_to_function(self._init_func_ast, + targets=[ScopedValue(ValueType.NamingValue, "self", node_name)], + expr=ScopedValue(ValueType.NamingValue, "global_vars", "get"), + args=[ScopedValue(ValueType.StringValue, "", node_name)]) + AstModifier.insert_assign_ast_to_function(self._root_ast, node_ast, + None if position is None else position.node.get_ast(), + position.before_node) + self._global_vars[node_name] = node.get_instance() + return node + + def append_node(self, node: Node, append_to_ast: bool = True) -> Node: + """ + Append a node to SymbolTree. + + Args: + node (Node): An instance of node to be appended. + append_to_ast (bool): A bool indicates whether to update corresponding ast node at same time, default is + True. + + Returns: + An instance of node which has been appended to SymbolTree. + """ + return self.insert_node(Position.create(self, self._tail, False), node, append_to_ast) + + def append_origin_field(self, node: Node) -> Node: + """ + Append an original field node to SymbolTree. An original field node represents a node created from existing + statement in forward method, from existing ast node in ast of forward method, so ast node do not need to update + while these nodes appending to SymbolTree. + This method is called while building SymbolTree usually. + + Args: + node (Node): An instance of node to be appended. + + Returns: + An instance of node which has been appended to SymbolTree. + """ + self._update_args_kwargs_for_unique(node) + if node.get_node_type() == NodeType.Output: + self._return = node + elif node.get_node_type() == NodeType.Input: + self._inputs.append(node) + return self.append_node(node, False) + + def append_input_node(self, param_name: str, default: Optional[ScopedValue] = None): + """ + Append an input node to SymbolTree corresponding to parameter of forward method of network class. + This method is called while building SymbolTree usually. + + Args: + param_name (str): A str represents name of parameter of forward method of network class. + default (Optional[ScopedValue] ): A ScopedValue represents default value of parameter. Default is None which + means parameter has no default value. + + Returns: + An instance of input node which has been appended to SymbolTree. + """ + if param_name == "self": + return + for input_node in self._inputs: + targets = input_node.get_targets() + if len(targets) != 1: + raise RuntimeError("targets should have 1 elements") + target: ScopedValue = targets[0] + if target.type != ValueType.NamingValue: + raise RuntimeError("target.type should equal to ValueType.NamingValue") + if target.scope != "": + raise RuntimeError("target.scope should be empty") + exist_param = target.value + if exist_param == param_name: + raise RuntimeError("input duplicated:", param_name) + input_node = Node.create_input_node(None, param_name, default, name=f"input_{param_name}") + self.append_origin_field(input_node) + + def try_append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST) -> Optional[Node]: + """ + Try appending a python node to SymbolTree if 'ast_node' is not None and 'ast_node' is not Empty if 'ast_node' is + a list or a dict. + This method is called while building SymbolTree usually. + + Args: + ast_scope (ast.AST): A ast node represents ast node of scope of node. + ast_node (ast.AST): A ast node represents ast node. + + Returns: + An instance of python node if a new node has been appended to SymbolTree else None. + """ + if ast_node is None: + return None + if isinstance(ast_node, (list, dict)) and not ast_node: + return None + return self.append_python_node(ast_scope, ast_node) + + def append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST) -> Node: + """ + Append a python node to SymbolTree. + This method is called while building SymbolTree usually. + + Args: + ast_scope (ast.AST): A ast node represents ast node of scope of node. + ast_node (ast.AST): A ast node represents ast node. + + Returns: + An instance of python node which has been appended to SymbolTree. + """ + logger.info("Ignoring unsupported node(%s) in %s.", type(ast_node).__name__, type(ast_scope).__name__) + node_name = self._node_name_namer.get_name(type(ast_node).__name__) + node = Node.create_python_node(ast_node, node_name) + self._insert_node(Position.create(self, self._tail, True), node) + return node + + def set_output(self, return_value: str, index: int) -> Node: + """ + Update return value of return of forward method of network class. + + Args: + return_value (str): A str represents new return value. + index (int): A int indicates which return value to be updated. + + Returns: + An instance of node represents return node after updated. + """ + if self._return is None: + raise RuntimeError("SymbolTree has no output") + self.set_node_arg(self._return, index, return_value) + return self._return + + def erase_node(self, node_or_name: Union[Node, str]) -> Node: + """ + Erase a node from SymbolTree. + Note: + If node is depended on by other node, RuntimeError will raise. + + Topological relation is updated. + + Args: + node_or_name (Union[Node, str]): An instance of node or a str represents name of node. + + Returns: + An instance of node which has been erased from SymbolTree. + + Raises: + RuntimeError: If 'node_or_name' is not in current SymbolTree. + RuntimeError: If erase corresponding ast node failed. + """ + + node = self._get_real_node(node_or_name) + if node is None: + raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name) + ret = AstModifier.erase_ast_from_function(self._root_ast, node.get_ast()) + if not ret: + raise RuntimeError("node not in function ast tree.") + for key, value in self._nodes.items(): + if id(value) == id(node): + self._nodes.pop(key) + value.isolate() + break + self._topo_mgr.on_erase_node(node) + return node + + def _insert_tree(self, position: Position, root: Node, insert_to_ast: bool = True) -> Node: + """ + Insert a node-tree into SymbolTree. + Note: + Inputs of intra sub-tree nodes need to be welly set. + + Inputs of inter sub-tree nodes will be updated by Rewrite automatically. + + Args: + position (Position): A Position indicates an insert position point. + root (Node): An instance of node as root of node-tree to be inserted in. + insert_to_ast (bool): A bool indicates whether to update corresponding ast node at same time, default is + True. + + Returns: + An instance of node as root node of node-tree which has been inserted into SymbolTree. + + Raises: + RuntimeError: If 'position' is not in current SymbolTree. + """ + + # if position not in current SymbolTree + if position.symbol_tree is not self: + raise RuntimeError("Position is not in current SymbolTree: ", position) + + queue: [Node] = [root] + todos: [] = [] + inputs_list: [] = [] + while queue: + cur_node = queue.pop(0) + if cur_node in todos: + continue + todos.append(cur_node) + node_inputs = cur_node.get_inputs() + inputs_list.append(node_inputs) + for node_input in node_inputs: + if node_input is not None: + queue.append(node_input) + todos.reverse() + inputs_list.reverse() + for index, todo in enumerate(todos): + self.insert_node(position, todo, insert_to_ast) + position = self.after(todo) + # relink input of node + original_inputs = inputs_list[index] + for arg_idx, original_input in enumerate(original_inputs): + if original_input is not None: + self.set_node_arg_by_node(todo, arg_idx, original_input) + return root + + @staticmethod + def _link_nodes_and_find_root(nodes: [Node]) -> Node: + """ + Find inputs for all nodes created by Replacement according to their targets and arguments. + Find root node of all nodes created by Replacement. One and Only one root should be found. + + Args: + nodes ([Node]): A list of instance of Node created by Replacement. + + Returns: + An instance of Node represents root of input nodes. + """ + consumers: [ScopedValue] = [] + target_dict: {ScopedValue: Node} = {} + for node in nodes: + consumers.extend(node.get_args()) + for _, arg in node.get_kwargs(): + consumers.append(arg) + for target in node.get_targets(): + if target_dict.get(target) is not None: + raise RuntimeError("Target of node duplicated") + target_dict[target] = node + # find root node + root = None + for node in nodes: + used = 0 + for target in node.get_targets(): + if target in consumers: + used += 1 + if used == 0: + if root is not None: + raise RuntimeError("Replacement should only has one root") + root = node + if root is None: + raise RuntimeError("No root node found in replacement nodes") + # link node's input + for node in nodes: + inputs = [] + for _, arg in node.get_normalized_args().items(): + node_input: Node = target_dict.get(arg) + if node_input is None: + inputs.append(None) + else: + inputs.append(node_input) + node.set_inputs(inputs) + return root + + def replace(self, old_node: Node, new_nodes: [Node]) -> Node: + """ + Replace an old_node with a node_tree. 'new_node' is the root node of the node_tree. + Note: + Rewrite will iterate all nodes linked to this root node and insert these nodes into symbol_tree. + + Inputs of intra sub-tree nodes need to be welly set. + + Inputs of inter sub-tree nodes will be updated by Rewrite automatically. + + Args: + old_node (Node): Node to be replaced. + new_nodes ([Node]): Node tree to replace in. + + Returns: + An instance of Node represents root of node_tree been replaced in. + + Raises: + RuntimeError: If 'old_node' is isolated. + RuntimeError: If 'old_node' is not belong to current SymbolTree. + """ + + real_old_node = self._get_real_node(old_node) + if real_old_node is None: + raise RuntimeError("Old node is not belong to current SymbolTree:", old_node) + # get position + next_node: Node = old_node.get_next() + prev_node: Node = old_node.get_prev() + if prev_node is None and next_node is None: + raise RuntimeError("Try replacing a isolated node: ", old_node) + if next_node is None: + position = self.after(prev_node) + else: + position = self.before(next_node) + # insert node first, because targets of new_node is determined after insert + new_tree_root = SymbolTree._link_nodes_and_find_root(new_nodes) + new_node = self._insert_tree(position, new_tree_root) + # use targets of insert tree to redirect edge + users = self.get_node_users(old_node) + if len(new_node.get_targets()) != 1: + raise RuntimeError("targets of new_node should have 1 elements") + for user in users: + self.set_node_arg_by_node(user[0], user[1], new_node) + # erase old_node after edge is redirected because node can be erased only when node is isolated topologically + self.erase_node(old_node) + return new_node + + def set_node_arg(self, node: Union[Node, str], index: int, arg: Union[ScopedValue, str]): + """ + Set argument of 'node'. + + Args: + node (Union[Node, str]): Node to be modified. Can be a node or name of node. + index (int): Indicate which input being modified. + arg (Union[ScopedValue, str]): New argument to been set. + + Raises: + RuntimeError: If 'node' is not belong to current SymbolTree. + """ + + real_node = self._get_real_node(node) + if real_node is None: + raise RuntimeError("Node is not belong to current SymbolTree: ", node) + + new_arg, old_arg = node.set_arg(arg, index) + self._topo_mgr.on_update_arg(node, index, old_arg, new_arg) + + def set_node_arg_by_node(self, dst_node: Union[Node, str], arg_idx: int, src_node: Union[Node, str], + out_idx: Optional[int] = None): + """ + Set argument of 'dst_node' by another Node. + + Args: + dst_node (Node): Node to be modified. Can be a node or name of node. + arg_idx (int): Indicate which input being modified. + src_node (Node): Node as new input. Can be a node or name of node. + out_idx (Optional[int]): Indicate which output of 'src_node' as new input of 'dst_node'. Default is None + which means use first output of 'node_to_link' as new input. + + Raises: + RuntimeError: If 'dst_node' is not belong to current SymbolTree. + RuntimeError: If 'src_node' is not belong to current SymbolTree. + RuntimeError: If 'out_idx' is out of range. + RuntimeError: If 'src_node' has multi-outputs while 'out_idx' is None or 'out_idx' is not offered. + """ + + real_dst_node = self._get_real_node(dst_node) + if real_dst_node is None: + raise RuntimeError("dst_node is not belong to current SymbolTree: ", dst_node) + real_src_node = self._get_real_node(src_node) + if real_src_node is None: + raise RuntimeError("src_node is not belong to current SymbolTree: ", src_node) + + targets = real_src_node.get_targets() + if out_idx is None: + if len(targets) != 1: + raise RuntimeError("node should has one output when out_idx is not provided") + out_idx = 0 + if out_idx >= len(targets): + raise RuntimeError("out_idx out of range: ", out_idx) + new_arg = targets[out_idx] + self.set_node_arg(real_dst_node, arg_idx, new_arg) + + def dump(self): + """Dump graph.""" + dump_st = SymbolTreeDumper(self) + dump_st.dump() + + def get_code(self) -> str: + """ + Get source code of modified network. + + Returns: + A str represents source code of modified network. + """ + ast.fix_missing_locations(self._module_ast) + return astunparse.unparse(self._module_ast) + + def get_network(self): + """ + Get modified network. + + Returns: + A network object. + """ + cls = self._get_cls_through_file() + return cls(self._global_vars) + + def _unique_targets(self, node: Node): + """ + Unique targets of node by _target_namer. + + Args: + node (Node): A Node whose targets to be uniqued. + """ + new_targets: [ScopedValue] = [] + if node.get_targets() is None: + return + for target in node.get_targets(): + if not isinstance(target, ScopedValue): + raise TypeError("target should be ScopedValue, got: ", type(target)) + unique_target = self._target_namer.get_name(target.value) + new_targets.append(ScopedValue.create_naming_value(unique_target, target.scope)) + node.set_targets(new_targets) + + def _update_args_kwargs_for_unique(self, node: Node): + """ + Update arguments and keyword arguments of node because unique-ing of targets of other nodes. + + Args: + node (Node): A Node whose arguments and keyword arguments to be updated. + """ + result: {str: ScopedValue} = {} + if node.get_normalized_args() is None: + return + for key, arg in node.get_normalized_args().items(): + if not isinstance(arg, ScopedValue): + raise TypeError("arg should be ScopedValue, got: ", type(arg)) + if arg.type == ValueType.NamingValue: + # unique name + new_arg = ScopedValue(arg.type, arg.scope, self._target_namer.get_real_arg(arg.value)) + result[key] = new_arg + else: + result[key] = arg + node.set_normalized_args(result) + + def _add_node2nodes(self, node: Node): + """ + Add `node` to `_nodes` dict. + + Args: + node (Node): A Node to be added into `_nodes`. + + Raises: + RuntimeError: If name of the node is duplicated. + """ + node_name = node.get_name() + if self._nodes.get(node_name) is not None: + raise RuntimeError("generated duplicated node name", node_name, self._nodes.get(node_name), + node) + self._nodes[node_name] = node + + def _insert_node(self, position: Optional[Position], node: Node): + """ + Insert a node into SymbolTree. + 1. Add `node` to `_nodes`. + 2. Insert `node` to node list(source code order). + 3. Update topological relation and update inputs of `node`. + + Args: + position (Optional[Position]): Indicates node insert position. Position is None when inserting first node of + SymbolTree. + node (Node): A Node to be inserted into SymbolTree. + + Raises: + RuntimeError: Position is None when _nodes of SymbolTree is not Empty. It means position can not be None + unless inserting first node. + """ + if position is None: + if self._nodes: + raise RuntimeError("self._nodes should be empty") + self._head = node + else: + if position.before_node: + position.node.insert_before(node) + else: + position.node.insert_after(node) + self._tail = node + self._add_node2nodes(node) + self._topo_mgr.on_insert_node(node) + node.set_belong_symbol_tree(self) + + def _handle_custom_obj_in_normalized_args(self, node: Node): + """ + Convert CustomObjValue type argument to NamingValue type argument by storing custom object in global_vars dict. + + Args: + node (Node): A Node whose arguments and keyword arguments to be handled. + """ + result: {str, ScopedValue} = {} + for arg, value in node.get_normalized_args().items(): + if not isinstance(value, ScopedValue): + raise TypeError("value should be ScopedValue, got: ", type(value)) + if value.type == ValueType.CustomObjValue: + field = self._node_name_namer.get_name(f"var_{type(value.value).__name__}") + self._global_vars[field] = value.value + init_targets = [ScopedValue.create_naming_value(field, "self")] + AstModifier.append_global_vars_expr_to_init(self._init_func_ast, init_targets, field) + result[arg] = init_targets[0] + else: + result[arg] = value + node.set_normalized_args(result) + + def _get_cls_through_file(self): + """ + Load rewritten network class of current SymbolTree. + 1. Get source code of current SymbolTree. + 2. Saving source code to a tempfile. + 3. Import rewritten network class using "__import__" function. + + Returns: + A class handle. + """ + source = self.get_code() + tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.py') + tmp_file.write(source.encode('utf8')) + tmp_file.flush() + tmp_file_name = tmp_file.name + tmp_module_path, tmp_module_file = os.path.split(tmp_file_name) + tmp_module_name = tmp_module_file[:-3] + sys.path.append(tmp_module_path) + tmp_module = __import__(tmp_module_name) + network_cls = getattr(tmp_module, self._opt_cls_name) + if network_cls is None: + raise RuntimeError("Can not find network class:", self._opt_cls_name) + return network_cls diff --git a/mindspore/python/mindspore/rewrite/symbol_tree_builder.py b/mindspore/python/mindspore/rewrite/symbol_tree_builder.py new file mode 100644 index 00000000000..6fd65889dc0 --- /dev/null +++ b/mindspore/python/mindspore/rewrite/symbol_tree_builder.py @@ -0,0 +1,73 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""SymbolTree builder.""" +from typing import Optional +import ast +import inspect + +from mindspore.nn import Cell +from .symbol_tree import SymbolTree +from .parser_register import ParserRegister +from .parser import Parser +from .ast_transformers import FlattenRecursiveStmt + + +class SymbolTreeBuilder: + """SymbolTree builder.""" + + def __init__(self, network: Cell): + """ + Constructor of SymbolTreeBuilder. + + Args: + network (Cell): An instance of Cell represents a network from which SymbolTree will be built. + """ + if not isinstance(network, Cell): + raise RuntimeError("Only support network with Cell type now") + self._origin_net = network + network_str = inspect.getsource(type(network)) + self._ast_root: ast.Module = ast.parse(network_str) + self._root_tree: Optional[SymbolTree] = None + + @staticmethod + def _ast_transform(ast_root: ast.AST) -> ast.AST: + """ + Optimize ast before parse. + + Args: + ast_root (ast.AST): An instance of ast to be optimized. + + Returns: + An instance of ast been optimized. + """ + transform_list = [FlattenRecursiveStmt()] + for transformer in transform_list: + ast_root = transformer.transform(ast_root) + return ast_root + + def build(self) -> SymbolTree: + """ + Build SymbolTree. + + Returns: + An instance of SymbolTree. + """ + self._ast_root = SymbolTreeBuilder._ast_transform(self._ast_root) + if not isinstance(self._ast_root, ast.Module): + raise RuntimeError("ast_root should be a ast.Module") + self._root_tree: SymbolTree = SymbolTree(self._origin_net, self._ast_root) + parser: Parser = ParserRegister.instance().get_parser(ast.Module) + parser.process(self._root_tree, self._ast_root) + return self._root_tree diff --git a/mindspore/python/mindspore/rewrite/symbol_tree_dumper.py b/mindspore/python/mindspore/rewrite/symbol_tree_dumper.py new file mode 100644 index 00000000000..f3b5f1200ee --- /dev/null +++ b/mindspore/python/mindspore/rewrite/symbol_tree_dumper.py @@ -0,0 +1,144 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""SymbolTree dumper.""" +import inspect + +from mindspore import log as logger +from .node import Node +from .api.node_type import NodeType +from .api.scoped_value import ScopedValue, ValueType + + +class SymbolTreeDumper: + """SymbolTree dumper.""" + + def __init__(self, symbol_tree): + """ + Constructor of SymbolTreeDumper. + + Args: + symbol_tree (SymbolTree): An instance of SymbolTree to be dumped. + """ + self._symbol_tree = symbol_tree + self._dump_buffer = "" + self._dump_key2index = {} + + def _reset(self): + """Reset SymbolTreeDumper.""" + self._dump_buffer = "" + self._dump_key2index = {} + + def _dump_global_info(self): + """Dump global info of SymbolTree.""" + self._dump_buffer += f"#SymbolTree entry : @construct \n" + + def _dump_inputs(self): + """Dump inputs of SymbolTree.""" + inputs = self._symbol_tree.get_inputs() + self._dump_buffer += f"#Inputs num : {len(inputs)}\n" + for single_input in inputs: + targets = single_input.get_targets() + if len(targets) != 1: + raise RuntimeError("Only support one output per node now") + target: ScopedValue = targets[0] + if target.type != ValueType.NamingValue: + raise RuntimeError("target.type should equal to ValueType.NamingValue") + if target.scope != "": + raise RuntimeError("target.scope should be empty") + input_arg = target.value + input_name = f"%input_{input_arg}" + if input_arg in self._dump_key2index.keys(): + raise RuntimeError("input_arg duplicated: ", input_arg) + self._dump_key2index[input_arg] = input_name + self._dump_buffer += f"{input_name}\n" + self._dump_buffer += f"\n" + + def _dump_nodes(self): + """Dump nodes of SymbolTree.""" + self._dump_buffer += f"Symbol Tree @construct {{ \n" + node_no = -1 + + node: Node = self._symbol_tree.get_head_node().get_next() + while node is not None: + if node.get_node_type() is NodeType.Output: + self._dump_buffer += f" Return(%{node_no}) \n" + self._dump_buffer += f" : (null) \n" + self._dump_buffer += f" # In file {inspect.getfile(type(self._symbol_tree.get_origin_network()))}" + + node = node.get_next() + continue + + node_no += 1 + self._dump_key2index[node.get_name()] = f"%{node_no}" + + targets = node.get_targets() + if not targets: + targets = [None] + op_type = node.get_instance_type() + if hasattr(op_type, "__name__"): + op_type_name = op_type.__name__ + else: + if hasattr(type(op_type), "__name__"): + op_type_name = type(op_type).__name__ + else: + raise RuntimeError("op has no attr __name__") + self._dump_buffer += f" %{node_no}({targets[0]}) = {op_type_name}" + + args = node.get_normalized_args().values() + if args: + arg_str = f"" + for arg in args: + if isinstance(arg, str): + arg_name = arg + elif isinstance(arg, ScopedValue): + arg_name = arg.value + else: + raise RuntimeError(f"arg type {type(arg)} of {arg} is not supported now") + + if arg_name in self._dump_key2index.keys(): + arg_str += f"{self._dump_key2index[arg_name]}, " + else: + logger.warning("arg not appears before") + arg_str += f"{arg_name}, " + self._dump_buffer += f"({arg_str[:-2]})" + + self._dump_buffer += f"{{instance name: {node.get_name()}}}" + + self._dump_buffer += f" attributes {{" + attrs = node.get_attributes() + if attrs: + attrs_str = f"" + for attr in attrs: + if not isinstance(attr, str): + raise TypeError("attr should be str, got: ", type(attr)) + attrs_str += f"{attr}: {attrs[attr]}, " + self._dump_buffer += attrs_str[:-2] + self._dump_buffer += f"}}\n" + + self._dump_buffer += f" : (null) -> (null)\n" + cls_real_path = inspect.getfile(node.get_instance_type()) if node.get_instance() else None + self._dump_buffer += f" # In file {cls_real_path}\n" + self._dump_buffer += f" # In file {inspect.getfile(type(self._symbol_tree.get_origin_network()))}\n" + + node = node.get_next() + self._dump_buffer += f"}}\n" + + def dump(self): + """Dump SymbolTree.""" + self._reset() + self._dump_global_info() + self._dump_inputs() + self._dump_nodes() + print(self._dump_buffer) diff --git a/mindspore/python/mindspore/rewrite/topological_manager.py b/mindspore/python/mindspore/rewrite/topological_manager.py new file mode 100644 index 00000000000..574e732cb4a --- /dev/null +++ b/mindspore/python/mindspore/rewrite/topological_manager.py @@ -0,0 +1,189 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""SymbolTree topological-relationship manager.""" +from typing import Tuple + +from .api.scoped_value import ScopedValue +from .node import Node + + +class TopoManager: + """SymbolTree topological-relationship manager.""" + + def __init__(self): + """ + Constructor of TopoManager. + Init provider and consumer. + Key of dict is an instance of ScopedValue. + Value of dict is a tuple whose first is an instance of Node, whose second is an index. + It means node's index th arg is argument + """ + self._target_provider: {ScopedValue, (Node, int)} = {} + self._target_consumer: {ScopedValue, [(Node, int)]} = {} + + def get_node_users(self, node: Node) -> [Tuple[Node, int]]: + """ + Get all nodes which depend on node corresponding to node_or_name. + + Args: + node (Node): An instance of node. + + Returns: + A list of nodes represents node users. + """ + targets = node.get_targets() + results = [] + for target in targets: + consumers = self._target_consumer.get(target) + if consumers is None: + continue + results.extend(consumers) + unique_results = [] + for result in results: + if result not in unique_results: + unique_results.append(result) + return unique_results + + def _add_consumer(self, product: ScopedValue, consumer: Node, index): + """ + Add a consumer to consumer dict. + + Args: + product (ScopedValue): An instance of ScopedValue represents product to be consumed. + consumer (Node): An instance of Node represents consumer. + index (int): A int represents which input of consumer is the product. + """ + consumers = self._target_consumer.get(product) + if consumers is None: + self._target_consumer[product] = [(consumer, index)] + else: + self._target_consumer[product].append((consumer, index)) + + def _erase_provider(self, product: ScopedValue): + """ + Erase a provider from provider dict. + + Args: + product (ScopedValue): An instance of ScopedValue represents product to be erased. + """ + if self._target_provider.get(product) is not None: + self._target_provider.pop(product) + + def _erase_consumer(self, product: ScopedValue, consumer: Node): + """ + Erase a consumer from consumer dict. + + Args: + product (ScopedValue): An instance of ScopedValue represents product whose consumer would be erased. + consumer (Node): An instance of Node which would be erased as a consumer. + """ + consumers = self._target_consumer.get(product) + if consumers is None: + return + for i in range(len(consumers) - 1, -1, -1): + exist_ele = consumers[i] + if id(exist_ele[0]) == id(consumer): + consumers.pop(i) + + def _update_node_inputs(self, node: Node) -> [Node]: + """ + Update inputs of node by current provider dict and consumer dict. + + Args: + node (Node): An instance of Node whose inputs will be updated. + + Returns: + A list of instance of nodes represents inputs of node. + """ + if node.get_normalized_args() is None: + node.set_inputs([]) + return [] + inputs = [] + for arg in node.get_normalized_args().values(): + provider = self._target_provider.get(arg) + # some arg of some node may be self.xxx which is not an output of another node + if provider is not None: + inputs.append(provider[0]) + node.set_inputs(inputs) + return inputs + + def on_insert_node(self, node: Node): + """ + Update provider dict and consumer dict while inserting node into SymbolTree and update inputs of node by updated + provider dict and consumer dict. + + Args: + node (Node): An instance of Node which been inserted into SymbolTree. + """ + if node.get_targets() is not None: + for i in range(0, len(node.get_targets())): + target = node.get_targets()[i] + if self._target_provider.get(target) is not None: + raise RuntimeError("target duplicated:", target) + self._target_provider[target] = (node, i) + if node.get_normalized_args() is not None: + for index, arg in enumerate(node.get_normalized_args().values()): + self._add_consumer(arg, node, index) + self._update_node_inputs(node) + + def on_erase_node(self, node: Node): + """ + Update provider dict and consumer dict while erasing node from SymbolTree. + + Args: + node (Node): An instance of Node which been erased from SymbolTree. + """ + if node.get_targets() is not None: + for target in node.get_targets(): + consumers = self._target_consumer.get(target) + if consumers is not None and consumers: + raise RuntimeError("Only support erase isolated node: ", node.get_name(), target) + self._erase_provider(target) + if node.get_normalized_args() is not None: + for arg in node.get_normalized_args().values(): + self._erase_consumer(arg, node) + # clear inputs of node rather than call _update_node_inputs because node is already erase from consumer dict + node.set_inputs([]) + + def on_update_arg(self, node: Node, arg_idx: int, old_arg: ScopedValue, new_arg: ScopedValue): + """ + Update provider dict and consumer dict while updating argument of node and update inputs of node by updated + provider dict and consumer dict. + + Args: + node (Node): An instance of Node whose arguments being updated. + arg_idx (int): An int indicates which argument of node being updated. + old_arg (ScopedValue): An instance of ScopedValue represents original argument. + new_arg (ScopedValue): An instance of ScopedValue represents new argument. + """ + self._erase_consumer(old_arg, node) + self._add_consumer(new_arg, node, arg_idx) + self._update_node_inputs(node) + + def dump(self, title=""): + """ + Dump topological relation. + + Args: + title (str): A string as a title will be printed before dumping topological relation. + """ + print(f"{title}------------------------------------------------------------------------------------") + for k, v in self._target_provider.items(): + print(f"{v[0].get_name()} produces {k.value}") + for k, v in self._target_consumer.items(): + print(f"{k.value} is consumed by: ") + for ele in v: + print(ele[0].get_name()) + print(f"-----------------------------------------------------------------------------------------") diff --git a/requirements.txt b/requirements.txt index 3a781b5821f..28e7297e0aa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,5 @@ pycocotools >= 2.0.2 # for st test tables >= 3.6.1 # for st test easydict >= 1.9 # for st test psutil >= 5.7.0 +astunparse >= 0.0 +astpretty >= 0.0 diff --git a/tests/ut/python/rewrite/__init__.py b/tests/ut/python/rewrite/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/ut/python/rewrite/test_flatten_recursive_stmt.py b/tests/ut/python/rewrite/test_flatten_recursive_stmt.py new file mode 100644 index 00000000000..f0e296eda8f --- /dev/null +++ b/tests/ut/python/rewrite/test_flatten_recursive_stmt.py @@ -0,0 +1,64 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import ast +import inspect + +from mindspore.nn import Cell, Conv2d, BatchNorm2d, ReLU +from mindspore.rewrite.ast_transformers.flatten_recursive_stmt import FlattenRecursiveStmt + + +class Network(Cell): + def __init__(self): + super().__init__() + self.conv = Conv2d(16, 16, 3) + self.bn = BatchNorm2d(16) + self.relu1 = ReLU() + self.relu2 = ReLU() + self.relu3 = ReLU() + + def construct(self, x): + x = self.conv(x + 1) + x = x + 1 * 5 + 4 / 2 + self.bn(x) + self.relu1(x * 5) + x = self.relu2(x + 1) + x = True and x or x + x = self.relu3(x) + return x + 3 + + +def _get_ast(): + source = inspect.getsource(Network) + return ast.parse(source) + + +def test_flatten(): + """ + Feature: Class FlattenRecursiveStmt. + Description: Apply FlattenRecursiveStmt on a simple network. + Expectation: Success. + """ + ast_node = _get_ast() + frs = FlattenRecursiveStmt() + frs.transform(ast_node) + assert len(ast_node.body) == 1 + ast_class = ast_node.body[0] + assert isinstance(ast_class, ast.ClassDef) + assert len(ast_class.body) == 2 + ast_init_func = ast_class.body[0] + assert isinstance(ast_init_func, ast.FunctionDef) + assert len(ast_init_func.body) == 6 + ast_construct_func = ast_class.body[1] + assert isinstance(ast_construct_func, ast.FunctionDef) + assert len(ast_construct_func.body) == 17 diff --git a/tests/ut/python/rewrite/test_net_simple.py b/tests/ut/python/rewrite/test_net_simple.py new file mode 100644 index 00000000000..ebebc8eb9c7 --- /dev/null +++ b/tests/ut/python/rewrite/test_net_simple.py @@ -0,0 +1,141 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore +from mindspore import Tensor, nn +from mindspore.rewrite import SymbolTree, ScopedValue, ValueType, Node +from mindspore.common.initializer import Normal +from mindspore.common.api import _cell_graph_executor + + +class SimpleNet(nn.Cell): + def __init__(self, num_class=10, num_channel=1): + super(SimpleNet, self).__init__() + self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid') + self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid') + self.relu = nn.ReLU() + self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = nn.Flatten() + self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02)) + self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02)) + self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02)) + self.var = 10 + + def construct(self, x): + x = self.conv1(x) + x = x + y = self.var + y = y * 5 + y = y and True + x = self.relu(x) + x = self.max_pool2d(x) + x = self.conv2(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.flatten(x) + x = self.relu(self.fc1(x)) + x = self.relu(self.fc2(x)) + x = self.fc3(x) + return x + + +class MyCell(nn.Cell): + def __init__(self): + super().__init__() + self.conv = nn.Dense(5, 16) + + def construct(self, x, y): + x = self.conv(x) + x = mindspore.ops.Add()(x, y) + return x + + +def add_conv_before_flatten(stree: SymbolTree): + new_conv_node = None + for node in stree.nodes(): + if node.get_instance_type() == mindspore.nn.Flatten: + position = stree.before(node) + new_conv = nn.Conv2d(16, 16, 3) + new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv', + args=[ScopedValue.create_naming_value('self_max_po')]) + stree.insert(position, new_conv_node) + break + if new_conv_node is not None: + for node in stree.nodes(): + if node.get_instance_type() == mindspore.nn.Flatten: + inputs = node.get_inputs() + assert len(inputs) == 1 + new_conv_node.set_arg_by_node(0, inputs[0]) + + +def add_my_cell_after_x_12(stree: SymbolTree): + for node in stree.nodes(): + targets = node.get_targets() + if targets is None: + continue + assert targets[0].type == ValueType.NamingValue + target = str(targets[0]) + if target == "x_12": + position = stree.after(node) + custom_cell = MyCell() + bias = Tensor(1, mindspore.int32) + new_custom_node = Node.create_call_cell(custom_cell, targets=['nx2'], + args=[ScopedValue.create_naming_value('nx3'), + ScopedValue.create_variable_value(bias)], name='my_cell') + stree.insert(position, new_custom_node) + new_custom_node.set_arg(0, "x_12") + break + + +def erase_node_x_11(stree: SymbolTree): + return_node = None + for node in stree.nodes(): + if node.get_targets() is None: + return_node = node + break + assert return_node is not None + for node in stree.nodes(): + targets = node.get_targets() + if targets is None: + continue + assert targets[0].type == ValueType.NamingValue + target = str(targets[0]) + if target == "x_11": + stree.set_output(0, "x_10") + stree.erase_node(node) + break + + +def transform(stree: SymbolTree): + add_conv_before_flatten(stree) + add_my_cell_after_x_12(stree) + erase_node_x_11(stree) + + +def test_simple_net(): + """ + Feature: Module rewrite. + Description: Resolve a simple network by rewrite and do some transform on it. + Expectation: Result of rewrite can be compiled. + """ + net = SimpleNet(10) + stree = SymbolTree(net) + transform(stree) + print("------------------------------------ keys of global_vars: ", + getattr(stree.get_handler(), "_global_vars").keys()) + net_opt = stree.get_network() + data_in = Tensor(np.ones([1, 1, 32, 32]), mindspore.float32) + _cell_graph_executor.compile(net_opt, data_in) diff --git a/tests/ut/python/rewrite/test_node.py b/tests/ut/python/rewrite/test_node.py new file mode 100644 index 00000000000..035d08d0f8e --- /dev/null +++ b/tests/ut/python/rewrite/test_node.py @@ -0,0 +1,126 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import ast + +from mindspore.nn import Cell +from mindspore.rewrite import ScopedValue +from mindspore.rewrite.node import Node + + +class FakeCell(Cell): + def construct(self, input1, input2, cool_boy=None): + return input1 + input2 + cool_boy + + +class FakeCell2(Cell): + def construct(self, a, b, d, e, *args, f=6, **kwargs): + return a + b + d + e + sum(args) + f + sum(kwargs.values()) + + +class FakeCell3(Cell): + def construct(self, a, b, *args, f=6, h=7, **kwargs): + return a + b + f + h + sum(args) + sum(kwargs.values()) + + +def test_create_by_cell(): + """ + Feature: Python api create_call_cell of Node of Rewrite. + Description: Call create_call_cell to create a node. + Expectation: Success. + """ + node = Node.create_call_cell(FakeCell(), None, ['x'], 'new_conv', + [ScopedValue.create_naming_value('x'), ScopedValue.create_variable_value(1)], + {"cool_boy": ScopedValue.create_naming_value('Naroto')}, 'new_conv') + assert node._args_num == 2 + assert node._kwargs_num == 1 + assert node._normalized_args_keys == ["input1", "input2", "cool_boy"] + + assert node._normalized_args == { + "input1": ScopedValue.create_naming_value('x'), + "input2": ScopedValue.create_variable_value(1), + "cool_boy": ScopedValue.create_naming_value('Naroto') + } + + ast_node: ast.Assign = node.get_ast() + assign_value: ast.Call = ast_node.value + args_ast = assign_value.args + keywords_ast = assign_value.keywords + assert len(args_ast) == 2 + assert len(keywords_ast) == 1 + assert keywords_ast[0].arg == "cool_boy" + assert isinstance(args_ast[0], ast.Name) + assert args_ast[0].id == "x" + assert isinstance(args_ast[1], ast.Constant) + assert args_ast[1].value == 1 + keyword_value_3 = keywords_ast[0].value + assert isinstance(keyword_value_3, ast.Name) + assert keyword_value_3.id == "Naroto" + + node.set_arg(ScopedValue.create_variable_value(2), 1) + assert isinstance(node.get_normalized_args().get("input2"), ScopedValue) + assert node.get_normalized_args().get("input2").value == 2 + ast_node: ast.Assign = node.get_ast() + assign_value: ast.Call = ast_node.value + args_ast = assign_value.args + assert args_ast[1].value == 2 + + args = node.get_args() + assert args == [ScopedValue.create_naming_value('x'), ScopedValue.create_variable_value(2)] + kwargs = node.get_kwargs() + assert kwargs == {"cool_boy": ScopedValue.create_naming_value('Naroto')} + + +def test_create_by_cell2(): + """ + Feature: Python api create_call_cell of Node of Rewrite. + Description: Call create_call_cell to create a node. + Expectation: Success. + """ + node = Node.create_call_cell(FakeCell2(), None, ['x'], 'new_conv', + [ScopedValue.create_naming_value('x'), ScopedValue.create_naming_value("x"), + ScopedValue.create_naming_value('x'), ScopedValue.create_naming_value("x"), + ScopedValue.create_naming_value('x'), ScopedValue.create_naming_value("x")], + {"cool_boy": ScopedValue.create_naming_value('Naroto')}, 'new_conv') + assert node.get_normalized_args() == { + "a": ScopedValue.create_naming_value('x'), + "b": ScopedValue.create_naming_value('x'), + "d": ScopedValue.create_naming_value('x'), + "e": ScopedValue.create_naming_value('x'), + "args_4": ScopedValue.create_naming_value('x'), + "args_5": ScopedValue.create_naming_value('x'), + "cool_boy": ScopedValue.create_naming_value('Naroto'), + } + + +def test_create_by_cell3(): + """ + Feature: Python api create_call_cell of Node of Rewrite. + Description: Call create_call_cell to create a node. + Expectation: Success. + """ + node = Node.create_call_cell(FakeCell3(), None, ['x'], 'new_conv', + [ScopedValue.create_naming_value('x'), ScopedValue.create_naming_value("x"), + ScopedValue.create_naming_value('x'), ScopedValue.create_naming_value("x")], + {"h": ScopedValue.create_naming_value(1), "f": ScopedValue.create_naming_value(2), + "cool_boy": ScopedValue.create_naming_value('Naroto')}, 'new_conv') + assert node.get_normalized_args() == { + "a": ScopedValue.create_naming_value('x'), + "b": ScopedValue.create_naming_value('x'), + "args_2": ScopedValue.create_naming_value('x'), + "args_3": ScopedValue.create_naming_value('x'), + "f": ScopedValue.create_naming_value(2), + "h": ScopedValue.create_naming_value(1), + "cool_boy": ScopedValue.create_naming_value('Naroto'), + } diff --git a/tests/ut/python/rewrite/test_pattern_engine.py b/tests/ut/python/rewrite/test_pattern_engine.py new file mode 100644 index 00000000000..2c7ebaac219 --- /dev/null +++ b/tests/ut/python/rewrite/test_pattern_engine.py @@ -0,0 +1,665 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import ast +from collections import OrderedDict + +from mindspore.nn import Cell, Conv2d, BatchNorm2d, ReLU +from mindspore.ops import Add, AddN +from mindspore.rewrite import ScopedValue, Node, SymbolTree +from mindspore.rewrite import PatternEngine, PatternNode, Replacement, VarNode + + +def test_tree_pattern_match(): + """ + Feature: Python api PatternEngine. + Description: Construct a tree PatternEngine and apply it on a SymbolTree, check SymbolTree after PatternEngine + applied. + Expectation: Success. + """ + assert True + + +def test_leak_pattern_match(): + """ + Feature: Python api PatternEngine. + Description: Construct a leaked tree PatternEngine and apply it on a SymbolTree, check SymbolTree after + PatternEngine applied. + Expectation: Failure. + """ + assert True + + +class ChainNetwork(Cell): + def __init__(self): + super().__init__() + self.conv = Conv2d(16, 16, 3) + self.bn = BatchNorm2d(16) + self.relu1 = ReLU() + self.relu2 = ReLU() + self.relu3 = ReLU() + + def construct(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu1(x) + x = self.relu2(x) + x = self.relu3(x) + return x + + +def test_one_to_one_pattern(): + """ + Feature: Python api PatternEngine. + Description: Construct a one-to-one PatternEngine and apply it on a SymbolTree, check SymbolTree after PatternEngine + applied. + Expectation: Success. + """ + + class BnReplacement(Replacement): + def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]: + assert is_chain_pattern + assert pattern.type() == BatchNorm2d + bn_node: Node = matched.get(pattern.name()) + assert bn_node is not None + + conv = Conv2d(16, 16, 3) + conv_node = Node.create_call_cell(conv, ['x1'], bn_node.get_args(), bn_node.get_kwargs()) + return [conv_node] + + class BnReplace(PatternEngine): + def __init__(self): + super().__init__([BatchNorm2d], BnReplacement()) + + net = ChainNetwork() + stree = SymbolTree(net) + conv = stree.get_node("conv") + bn = stree.get_node("bn") + relu1 = stree.get_node("relu1") + construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast") + assert conv is not None + assert bn is not None + assert relu1 is not None + assert len(construct_ast.body) == 6 + assert len(stree.nodes()) == 7 + + bn_replace = BnReplace() + bn_replace.apply(stree) + + assert len(construct_ast.body) == 6 + assert len(stree.nodes()) == 7 + conv = stree.get_node("conv") + bn = stree.get_node("bn") + relu1 = stree.get_node("relu1") + new_conv = stree.get_node("x1") + assert conv is not None + assert bn is None + assert relu1 is not None + assert new_conv is not None + + # check conv topological order + assert len(conv.get_users()) == 1 + assert conv.get_users()[0] == new_conv + # check new_conv topological order + assert len(new_conv.get_inputs()) == 1 + assert new_conv.get_inputs()[0] == conv + assert len(new_conv.get_users()) == 1 + assert new_conv.get_users()[0] == relu1 + # check source code order + assert getattr(conv.get_handler(), "_next") == new_conv.get_handler() + assert getattr(new_conv.get_handler(), "_next") == relu1.get_handler() + assert getattr(relu1.get_handler(), "_prev") == new_conv.get_handler() + assert getattr(new_conv.get_handler(), "_prev") == conv.get_handler() + # # check arg edge + assert len(conv.get_targets()) == 1 + assert len(new_conv.get_args()) == 1 + assert conv.get_targets()[0] == new_conv.get_args()[0] + assert len(new_conv.get_targets()) == 1 + assert len(relu1.get_args()) == 1 + assert new_conv.get_targets()[0] == relu1.get_args()[0] + + +def test_one_to_multi_chain_pattern(): + """ + Feature: Python api PatternEngine. + Description: Construct a one-to-multi PatternEngine and apply it on a SymbolTree, check SymbolTree after + PatternEngine applied. + Expectation: Success. + """ + + class BnReplacement(Replacement): + def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]: + assert is_chain_pattern + assert pattern.type() == BatchNorm2d + bn_node: Node = matched.get(pattern.name()) + assert bn_node is not None + + # Replacement should ensure target is unique in result + # Replacement should ensure args and kwargs are well set by topological relation + conv1 = Conv2d(16, 16, 3) + conv_node1 = Node.create_call_cell(conv1, ['x1'], bn_node.get_args(), bn_node.get_kwargs()) + conv2 = Conv2d(16, 16, 5) + conv_node2 = Node.create_call_cell(conv2, ['x2'], [ScopedValue.create_naming_value('x1')]) + return [conv_node1, conv_node2] + + class BnReplace(PatternEngine): + def __init__(self): + super().__init__([BatchNorm2d], BnReplacement()) + + net = ChainNetwork() + stree = SymbolTree(net) + conv = stree.get_node("conv") + bn = stree.get_node("bn") + relu1 = stree.get_node("relu1") + construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast") + assert conv is not None + assert bn is not None + assert relu1 is not None + assert len(construct_ast.body) == 6 + assert len(stree.nodes()) == 7 + + bn_replace = BnReplace() + bn_replace.apply(stree) + + assert len(construct_ast.body) == 7 + assert len(stree.nodes()) == 8 + conv = stree.get_node("conv") + bn = stree.get_node("bn") + relu1 = stree.get_node("relu1") + new_conv1 = stree.get_node("x1") + new_conv2 = stree.get_node("x2") + assert conv is not None + assert bn is None + assert relu1 is not None + assert new_conv1 is not None + assert new_conv2 is not None + + # check conv topological order + assert len(conv.get_users()) == 1 + assert conv.get_users()[0] == new_conv1 + # check new_conv1 topological order + assert len(new_conv1.get_inputs()) == 1 + assert new_conv1.get_inputs()[0] == conv + assert len(new_conv1.get_users()) == 1 + assert new_conv1.get_users()[0] == new_conv2 + # check new_conv2 topological order + assert len(new_conv2.get_inputs()) == 1 + assert new_conv2.get_inputs()[0] == new_conv1 + assert len(new_conv2.get_users()) == 1 + assert new_conv2.get_users()[0] == relu1 + # check source code order + assert getattr(conv.get_handler(), "_next") == new_conv1.get_handler() + assert getattr(new_conv1.get_handler(), "_next") == new_conv2.get_handler() + assert getattr(new_conv2.get_handler(), "_next") == relu1.get_handler() + assert getattr(relu1.get_handler(), "_prev") == new_conv2.get_handler() + assert getattr(new_conv2.get_handler(), "_prev") == new_conv1.get_handler() + assert getattr(new_conv1.get_handler(), "_prev") == conv.get_handler() + # check arg edge + assert len(conv.get_targets()) == 1 + assert len(new_conv1.get_args()) == 1 + assert conv.get_targets()[0] == new_conv1.get_args()[0] + + assert len(new_conv1.get_targets()) == 1 + assert len(new_conv2.get_args()) == 1 + assert new_conv1.get_targets()[0] == new_conv2.get_args()[0] + + assert len(new_conv2.get_targets()) == 1 + assert len(relu1.get_args()) == 1 + assert new_conv2.get_targets()[0] == relu1.get_args()[0] + + +class TreeNetwork(Cell): + def __init__(self): + super().__init__() + self.conv1 = Conv2d(16, 16, 3) + self.conv2 = Conv2d(16, 16, 5) + self.add = Add() + self.relu = ReLU() + self.relu1 = ReLU() + self.relu2 = ReLU() + + def construct(self, x): + x1 = self.conv1(x) + x2 = self.conv2(x) + x = self.add(x1, x2) + x = self.relu(x) + x1 = self.relu1(x) + x2 = self.relu2(x) + x = self.add(x1, x2) + return x + + +def test_tree_pattern(): + """ + Feature: Python api PatternEngine. + Description: Construct a multi-to-multi PatternEngine and apply it on a SymbolTree, check SymbolTree after + PatternEngine applied. + Expectation: Success. + """ + + class AddReluReplacement(Replacement): + def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]: + assert is_chain_pattern + assert pattern.type() == ReLU + relu_node: Node = matched.get(pattern.name()) + assert relu_node is not None + assert len(pattern.get_inputs()) == 1 + add_pattern = pattern.get_inputs()[0] + assert add_pattern.type() == Add + add_node: Node = matched.get(add_pattern.name()) + assert add_node is not None + assert not add_pattern.get_inputs() + + # can not use add_node here + new_add1 = Add() + new_add1_node = Node.create_call_cell(new_add1, ['new_add_1'], add_node.get_args(), add_node.get_kwargs()) + new_relu1 = ReLU() + new_relu1_node = Node.create_call_cell(new_relu1, ['new_relu_1'], + [ScopedValue.create_naming_value('new_add_1')]) + new_relu2 = ReLU() + new_relu2_node = Node.create_call_cell(new_relu2, ['new_relu_2'], + [ScopedValue.create_naming_value('new_add_1')]) + new_add2 = Add() + new_add2_node = Node.create_call_cell(new_add2, ['new_add_2'], + [ScopedValue.create_naming_value('new_relu_1'), + ScopedValue.create_naming_value('new_relu_2')]) + return [new_add1_node, new_relu1_node, new_relu2_node, new_add2_node] + + class AddReluPattern(PatternEngine): + def __init__(self): + super().__init__([Add, ReLU], AddReluReplacement()) + + net = TreeNetwork() + stree = SymbolTree(net) + conv1 = stree.get_node("conv1") + conv2 = stree.get_node("conv2") + add = stree.get_node("add") + relu = stree.get_node("relu") + relu1 = stree.get_node("relu1") + relu2 = stree.get_node("relu2") + assert conv1 is not None + assert conv2 is not None + assert add is not None + assert relu is not None + assert relu1 is not None + assert relu2 is not None + construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast") + assert len(construct_ast.body) == 8 + assert len(stree.nodes()) == 9 + + add_relu_pattern = AddReluPattern() + add_relu_pattern.apply(stree) + + assert len(construct_ast.body) == 10 + assert len(stree.nodes()) == 11 + conv1 = stree.get_node("conv1") + conv2 = stree.get_node("conv2") + add = stree.get_node("add") + relu = stree.get_node("relu") + relu1 = stree.get_node("relu1") + relu2 = stree.get_node("relu2") + new_add = stree.get_node("new_add") + new_relu = stree.get_node("new_relu") + new_relu_1 = stree.get_node("new_relu_1") + new_add_1 = stree.get_node("new_add_1") + + assert conv1 is not None + assert conv2 is not None + assert add is None + assert relu is None + assert relu1 is not None + assert relu2 is not None + assert new_add is not None + assert new_relu is not None + assert new_relu_1 is not None + assert new_add_1 is not None + + # check conv1 topological order + assert len(conv1.get_users()) == 1 + assert conv1.get_users()[0] == new_add + # check conv2 topological order + assert len(conv2.get_users()) == 1 + assert conv2.get_users()[0] == new_add + # check new_add topological order + assert len(new_add.get_inputs()) == 2 + assert new_add.get_inputs()[0] == conv1 + assert new_add.get_inputs()[1] == conv2 + assert len(new_add.get_users()) == 2 + assert new_add.get_users()[0] == new_relu + assert new_add.get_users()[1] == new_relu_1 + # check new_relu topological order + assert len(new_relu.get_inputs()) == 1 + assert new_relu.get_inputs()[0] == new_add + assert len(new_relu.get_users()) == 1 + assert new_relu.get_users()[0] == new_add_1 + # check new_relu_1 topological order + assert len(new_relu_1.get_inputs()) == 1 + assert new_relu_1.get_inputs()[0] == new_add + assert len(new_relu_1.get_users()) == 1 + assert new_relu_1.get_users()[0] == new_add_1 + # check new_add_1 topological order + assert len(new_add_1.get_inputs()) == 2 + assert new_add_1.get_inputs()[0] == new_relu_1 + assert new_add_1.get_inputs()[1] == new_relu + assert len(new_add_1.get_users()) == 2 + assert new_add_1.get_users()[0] == relu1 + assert new_add_1.get_users()[1] == relu2 + # check source code order + assert getattr(conv1.get_handler(), "_next") == conv2.get_handler() + assert getattr(conv2.get_handler(), "_next") == new_add.get_handler() + assert getattr(new_add.get_handler(), "_next") == new_relu.get_handler() + assert getattr(new_relu.get_handler(), "_next") == new_relu_1.get_handler() + assert getattr(new_relu_1.get_handler(), "_next") == new_add_1.get_handler() + assert getattr(new_add_1.get_handler(), "_next") == relu1.get_handler() + assert getattr(relu1.get_handler(), "_prev") == new_add_1.get_handler() + assert getattr(new_add_1.get_handler(), "_prev") == new_relu_1.get_handler() + assert getattr(new_relu_1.get_handler(), "_prev") == new_relu.get_handler() + assert getattr(new_relu.get_handler(), "_prev") == new_add.get_handler() + assert getattr(new_add.get_handler(), "_prev") == conv2.get_handler() + assert getattr(conv2.get_handler(), "_prev") == conv1.get_handler() + # check arg edge + assert len(conv1.get_targets()) == 1 + assert len(conv2.get_targets()) == 1 + assert len(new_add.get_args()) == 2 + assert conv1.get_targets()[0] == new_add.get_args()[0] + assert conv2.get_targets()[0] == new_add.get_args()[1] + + assert len(new_add.get_targets()) == 1 + assert len(new_relu.get_args()) == 1 + assert len(new_relu_1.get_args()) == 1 + assert new_add.get_targets()[0] == new_relu.get_args()[0] + assert new_add.get_targets()[0] == new_relu_1.get_args()[0] + + assert len(new_relu.get_targets()) == 1 + assert len(new_relu_1.get_targets()) == 1 + assert len(new_add_1.get_args()) == 2 + assert new_relu.get_targets()[0] == new_add_1.get_args()[1] + assert new_relu_1.get_targets()[0] == new_add_1.get_args()[0] + + assert len(new_add_1.get_targets()) == 1 + assert len(relu1.get_args()) == 1 + assert len(relu2.get_args()) == 1 + assert new_add_1.get_targets()[0] == relu1.get_args()[0] + assert new_add_1.get_targets()[0] == relu2.get_args()[0] + + +class TreeNetwork2(Cell): + def __init__(self): + super().__init__() + self.conv1 = Conv2d(16, 16, 1) + self.conv2 = Conv2d(16, 16, 3) + self.add1 = AddN() + self.add2 = AddN() + self.relu = ReLU() + + def construct(self, x, y, z): + x = self.conv1(x) + y = self.conv2(y) + z = self.add1(x, y, z) + z = self.add2(x, y, z) + z = self.relu(z) + return z + + +class MultiInputPattern(PatternEngine): + class MultiInputReplacement(Replacement): + def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]: + assert not is_chain_pattern + assert pattern.type() == AddN + addn2_node: Node = matched.get(pattern.name()) + assert addn2_node is not None + assert len(pattern.get_inputs()) == 3 + conv1_pn = pattern.get_inputs()[0] + conv2_pn = pattern.get_inputs()[1] + addn1_pn = pattern.get_inputs()[2] + assert conv1_pn.type() == Conv2d + assert conv2_pn.type() == Conv2d + assert addn1_pn.type() == AddN + conv1_node: Node = matched.get(conv1_pn.name()) + conv2_node: Node = matched.get(conv2_pn.name()) + addn1_node: Node = matched.get(addn1_pn.name()) + assert conv1_node is not None + assert conv2_node is not None + assert addn1_node is not None + assert len(conv1_node.get_inputs()) == 1 + assert len(conv2_node.get_inputs()) == 1 + assert len(addn1_node.get_inputs()) == 3 + arg1 = conv1_node.get_args()[0] + arg2 = conv2_node.get_args()[0] + arg3 = addn1_node.get_args()[2] + + # can not use add_node here + new_add1 = Add() + new_add1_node = Node.create_call_cell(new_add1, ['new_add1'], [arg1, arg2]) + new_add2 = Add() + new_add2_node = Node.create_call_cell(new_add2, ['new_add2'], [ScopedValue.create_naming_value('new_add1'), + arg3]) + return [new_add1_node, new_add2_node] + + def __init__(self): + conv1_pn = PatternNode("conv1", Conv2d) + conv2_pn = PatternNode("conv2", Conv2d) + addn1_pn = PatternNode("addn1", AddN) + addn2_pn = PatternNode("addn2", AddN) + conv1_pn.set_inputs([VarNode()]) + conv2_pn.set_inputs([VarNode()]) + addn1_pn.set_inputs([conv1_pn, conv2_pn, VarNode()]) + addn2_pn.set_inputs([conv1_pn, conv2_pn, addn1_pn]) + super().__init__(addn2_pn, MultiInputPattern.MultiInputReplacement()) + + +def test_multi_input_to_multi_pattern_tree_pattern(): + """ + Feature: Python api PatternEngine. + Description: Construct a multi-to-multi PatternEngine and apply it on a SymbolTree, check SymbolTree after + PatternEngine applied. + Expectation: Success. + """ + + net = TreeNetwork2() + stree = SymbolTree(net) + conv1 = stree.get_node("conv1") + conv2 = stree.get_node("conv2") + add1 = stree.get_node("add1") + add2 = stree.get_node("add2") + relu = stree.get_node("relu") + assert conv1 is not None + assert conv2 is not None + assert add1 is not None + assert add2 is not None + assert relu is not None + construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast") + assert len(construct_ast.body) == 6 + assert len(stree.nodes()) == 9 + + multi_input_pattern = MultiInputPattern() + multi_input_pattern.apply(stree) + + assert len(construct_ast.body) == 4 + assert len(stree.nodes()) == 7 + conv1 = stree.get_node("conv1") + conv2 = stree.get_node("conv2") + add1 = stree.get_node("add1") + add2 = stree.get_node("add2") + relu = stree.get_node("relu") + new_add1 = stree.get_node("new_add1") + new_add2 = stree.get_node("new_add2") + inputx = stree.get_node("input_x") + inputy = stree.get_node("input_y") + inputz = stree.get_node("input_z") + assert conv1 is None + assert conv2 is None + assert add1 is None + assert add2 is None + assert relu is not None + assert new_add1 is not None + assert new_add2 is not None + assert inputx is not None + assert inputy is not None + assert inputz is not None + + # check inputx topological order + assert len(inputx.get_users()) == 1 + assert inputx.get_users()[0] == new_add1 + # check inputy topological order + assert len(inputy.get_users()) == 1 + assert inputy.get_users()[0] == new_add1 + # check inputz topological order + assert len(inputz.get_users()) == 1 + assert inputz.get_users()[0] == new_add2 + # check new_add1 topological order + assert len(new_add1.get_inputs()) == 2 + assert new_add1.get_inputs()[0] == inputx + assert new_add1.get_inputs()[1] == inputy + assert len(new_add1.get_users()) == 1 + assert new_add1.get_users()[0] == new_add2 + # check new_add2 topological order + assert len(new_add2.get_inputs()) == 2 + assert new_add2.get_inputs()[0] == new_add1 + assert new_add2.get_inputs()[1] == inputz + assert len(new_add2.get_users()) == 1 + assert new_add2.get_users()[0] == relu + # check relu topological order + assert len(relu.get_inputs()) == 1 + assert relu.get_inputs()[0] == new_add2 + # check source code order + assert getattr(inputz.get_handler(), "_next") == new_add1.get_handler() + assert getattr(new_add1.get_handler(), "_next") == new_add2.get_handler() + assert getattr(new_add2.get_handler(), "_next") == relu.get_handler() + assert getattr(relu.get_handler(), "_prev") == new_add2.get_handler() + assert getattr(new_add2.get_handler(), "_prev") == new_add1.get_handler() + assert getattr(new_add1.get_handler(), "_prev") == inputz.get_handler() + # check arg edge + assert len(inputx.get_targets()) == 1 + assert len(inputy.get_targets()) == 1 + assert len(new_add1.get_args()) == 2 + assert inputx.get_targets()[0] == new_add1.get_args()[0] + assert inputy.get_targets()[0] == new_add1.get_args()[1] + + assert len(inputz.get_targets()) == 1 + assert len(new_add1.get_targets()) == 1 + assert len(new_add2.get_args()) == 2 + assert new_add1.get_targets()[0] == new_add2.get_args()[0] + assert inputz.get_targets()[0] == new_add2.get_args()[1] + + assert len(new_add2.get_targets()) == 1 + assert len(relu.get_args()) == 1 + assert new_add2.get_targets()[0] == relu.get_args()[0] + + +class TreeNetwork3(Cell): + def __init__(self): + super().__init__() + self.conv1 = Conv2d(16, 16, 1) + self.conv2 = Conv2d(16, 16, 3) + self.add1 = AddN() + self.add2 = AddN() + self.relu = ReLU() + + def construct(self, x): + y = self.conv1(x) + z = self.conv2(x) + x = self.add1(y, z, x) + x = self.add2(y, z, x) + x = self.relu(x) + return x + + +def test_one_input_to_multi_pattern_tree_pattern(): + """ + Feature: Python api PatternEngine. + Description: Construct a multi-to-multi PatternEngine and apply it on a SymbolTree, check SymbolTree after + PatternEngine applied. + Expectation: Success. + """ + + net = TreeNetwork3() + stree = SymbolTree(net) + conv1 = stree.get_node("conv1") + conv2 = stree.get_node("conv2") + add1 = stree.get_node("add1") + add2 = stree.get_node("add2") + relu = stree.get_node("relu") + assert conv1 is not None + assert conv2 is not None + assert add1 is not None + assert add2 is not None + assert relu is not None + construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast") + assert len(construct_ast.body) == 6 + assert len(stree.nodes()) == 7 + + multi_input_pattern = MultiInputPattern() + multi_input_pattern.apply(stree) + + assert len(construct_ast.body) == 4 + assert len(stree.nodes()) == 5 + conv1 = stree.get_node("conv1") + conv2 = stree.get_node("conv2") + add1 = stree.get_node("add1") + add2 = stree.get_node("add2") + relu = stree.get_node("relu") + new_add1 = stree.get_node("new_add1") + new_add2 = stree.get_node("new_add2") + inputx = stree.get_node("input_x") + assert conv1 is None + assert conv2 is None + assert add1 is None + assert add2 is None + assert relu is not None + assert new_add1 is not None + assert new_add2 is not None + assert inputx is not None + + # check inputx topological order + assert len(inputx.get_users()) == 2 + assert inputx.get_users()[0] == new_add1 + assert inputx.get_users()[1] == new_add2 + # check new_add1 topological order + assert len(new_add1.get_inputs()) == 2 + assert new_add1.get_inputs()[0] == inputx + assert new_add1.get_inputs()[1] == inputx + assert len(new_add1.get_users()) == 1 + assert new_add1.get_users()[0] == new_add2 + # check new_add2 topological order + assert len(new_add2.get_inputs()) == 2 + assert new_add2.get_inputs()[0] == new_add1 + assert new_add2.get_inputs()[1] == inputx + assert len(new_add2.get_users()) == 1 + assert new_add2.get_users()[0] == relu + # check relu topological order + assert len(relu.get_inputs()) == 1 + assert relu.get_inputs()[0] == new_add2 + # check source code order + assert getattr(inputx.get_handler(), "_next") == new_add1.get_handler() + assert getattr(new_add1.get_handler(), "_next") == new_add2.get_handler() + assert getattr(new_add2.get_handler(), "_next") == relu.get_handler() + assert getattr(relu.get_handler(), "_prev") == new_add2.get_handler() + assert getattr(new_add2.get_handler(), "_prev") == new_add1.get_handler() + assert getattr(new_add1.get_handler(), "_prev") == inputx.get_handler() + # check arg edge + assert len(inputx.get_targets()) == 1 + assert len(new_add1.get_args()) == 2 + assert inputx.get_targets()[0] == new_add1.get_args()[0] + assert inputx.get_targets()[0] == new_add1.get_args()[1] + + assert len(inputx.get_targets()) == 1 + assert len(new_add1.get_targets()) == 1 + assert len(new_add2.get_args()) == 2 + assert new_add1.get_targets()[0] == new_add2.get_args()[0] + assert inputx.get_targets()[0] == new_add2.get_args()[1] + + assert len(new_add2.get_targets()) == 1 + assert len(relu.get_args()) == 1 + assert new_add2.get_targets()[0] == relu.get_args()[0] diff --git a/tests/ut/python/rewrite/test_symbol_tree.py b/tests/ut/python/rewrite/test_symbol_tree.py new file mode 100644 index 00000000000..edc87d614d6 --- /dev/null +++ b/tests/ut/python/rewrite/test_symbol_tree.py @@ -0,0 +1,388 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import ast +import inspect + +from mindspore.nn import Cell, Conv2d, BatchNorm2d, ReLU +from mindspore.ops import Add +from mindspore.rewrite import ScopedValue, ValueType, NodeType +from mindspore.rewrite import Node as NodeApi +from mindspore.rewrite.symbol_tree import SymbolTree +from mindspore.rewrite.node import Node + + +class Network(Cell): + def __init__(self): + super().__init__() + self.conv = Conv2d(16, 16, 3) + self.bn = BatchNorm2d(16) + self.relu1 = ReLU() + self.relu2 = ReLU() + self.relu3 = ReLU() + + def construct(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu1(x) + x = self.relu2(x) + x = self.relu3(x) + return x + + +def create_symbol_tree(): + net = Network() + source = inspect.getsource(type(net)) + ast_root = ast.parse(source) + ast_module = ast_root + assert isinstance(ast_root, ast.Module) + ast_class = ast_module.body[0] + assert isinstance(ast_class, ast.ClassDef) + ast_init_func = ast_class.body[0] + assert isinstance(ast_init_func, ast.FunctionDef) + ast_construct_func = ast_class.body[1] + assert isinstance(ast_construct_func, ast.FunctionDef) + ast_conv = ast_construct_func.body[0] + ast_bn = ast_construct_func.body[1] + ast_relu1 = ast_construct_func.body[2] + ast_relu2 = ast_construct_func.body[3] + ast_relu3 = ast_construct_func.body[4] + ast_return = ast_construct_func.body[5] + stree = SymbolTree(net, ast_module) + stree.set_class_ast(ast_class) + stree.set_init_func_ast(ast_init_func) + stree.set_ast_root(ast_construct_func) + stree.append_input_node("x") + conv_node = Node.create_call_cell(net.conv, ast_conv, [ScopedValue.create_naming_value("x")], + ScopedValue.create_naming_value("conv", "self"), + [ScopedValue.create_naming_value("x")], + {}, + "conv") + stree.append_origin_field(conv_node) + bn_node = Node.create_call_cell(net.bn, ast_bn, [ScopedValue.create_naming_value("x")], + ScopedValue.create_naming_value("bn", "self"), + [ScopedValue.create_naming_value("x")], {}, + "bn") + bn_node = stree.append_origin_field(bn_node) + relu1_node = Node.create_call_cell(net.relu1, ast_relu1, [ScopedValue.create_naming_value("x")], + ScopedValue.create_naming_value("relu1", "self"), + [ScopedValue.create_naming_value("x")], + {}, "relu1") + relu1_node = stree.append_origin_field(relu1_node) + relu2_node = Node.create_call_cell(net.relu2, ast_relu2, [ScopedValue.create_naming_value("x")], + ScopedValue.create_naming_value("relu2", "self"), + [ScopedValue.create_naming_value("x")], + {}, "relu2") + relu2_node = stree.append_origin_field(relu2_node) + relu3_node = Node.create_call_cell(net.relu3, ast_relu3, [ScopedValue.create_naming_value("x")], + ScopedValue.create_naming_value("relu3", "self"), + [ScopedValue.create_naming_value("x")], + {}, "relu3") + stree.append_origin_field(relu3_node) + node_return = Node.create_output_node(ast_return, ["x"]) + stree.append_origin_field(node_return) + return stree, bn_node, relu1_node, relu2_node + + +def test_insert_node(): + """ + Feature: Python api insert_node of SymbolTree of Rewrite. + Description: Call insert_node to insert a node into SymbolTree. + Expectation: Success. + """ + stree, _, relu1, relu2 = create_symbol_tree() + construct_ast: ast.FunctionDef = getattr(stree, "_root_ast") + providers = getattr(getattr(stree, "_topo_mgr"), "_target_provider") + consumers = getattr(getattr(stree, "_topo_mgr"), "_target_consumer") + providers_len = len(providers) + consumers_len = len(consumers) + assert len(stree.nodes()) == 7 + assert len(construct_ast.body) == 6 + assert len(relu1.get_targets()) == 1 + assert len(relu2.get_normalized_args().values()) == 1 + assert relu1.get_targets()[0] == list(relu2.get_normalized_args().values())[0] + input1 = 1 + node = Node.create_call_cell(Add(), None, ['x'], 'new_conv', + [ScopedValue.create_naming_value('x'), ScopedValue.create_variable_value(input1)], {}, + 'new_conv') + position = stree.before(relu2) + node = stree.insert_node(position, node) + # check nodes size + assert len(stree.nodes()) == 8 + # check args + assert len(relu2.get_normalized_args().values()) == 1 + assert relu1.get_targets()[0] == list(relu2.get_normalized_args().values())[0] + assert len(node.get_normalized_args().values()) == 2 + assert list(node.get_normalized_args().values())[0] == ScopedValue.create_naming_value('x') + assert list(node.get_normalized_args().values())[1].type == ValueType.IntValue + # check provider + assert len(providers) == providers_len + 1 + assert len(node.get_targets()) == 1 + assert providers.get(node.get_targets()[0])[0] == node + assert providers.get(node.get_targets()[0])[1] == 0 + # check consumer + assert len(consumers) == consumers_len + 1 + assert consumers.get(list(node.get_normalized_args().values())[1]) is not None + # check inputs + assert len(relu2.get_inputs()) == 1 + assert relu2.get_inputs()[0] == relu1 + assert len(node.get_inputs()) == 1 + assert node.get_inputs()[0].get_node_type() == NodeType.Input + # check ast + node_ast = node.get_ast() + assert isinstance(node_ast, ast.Assign) + args = node_ast.value.args + assert isinstance(args, list) + assert len(args) == 2 + assert isinstance(args[0], ast.Name) + assert isinstance(args[1], ast.Constant) + assert len(construct_ast.body) == 7 + + +def test_set_node_arg(): + """ + Feature: Python api set_node_arg of SymbolTree of Rewrite. + Description: Call set_node_arg to change topological-order of a node. + Expectation: Success. + """ + stree, bn, relu1, relu2 = create_symbol_tree() + assert len(stree.nodes()) == 7 + assert len(bn.get_targets()) == 1 + bn_output = bn.get_targets()[0] + # check bn topological order + assert len(stree.get_node_users(bn)) == 1 + assert stree.get_node_users(bn)[0][0] == relu1 + # check relu1 topological order + assert len(stree.get_node_inputs(relu1)) == 1 + assert stree.get_node_inputs(relu1)[0] == bn + assert len(stree.get_node_users(relu1)) == 1 + assert stree.get_node_users(relu1)[0][0] == relu2 + # check relu2 topological order + assert len(stree.get_node_inputs(relu2)) == 1 + assert stree.get_node_inputs(relu2)[0] == relu1 + # check relu1 and relu2 edge + assert len(relu1.get_targets()) == 1 + assert len(relu2.get_normalized_args().values()) == 1 + assert relu1.get_targets()[0] == list(relu2.get_normalized_args().values())[0] + + stree.set_node_arg(relu2, 0, bn_output) + # check bn topological order + assert len(stree.get_node_users(bn)) == 2 + assert stree.get_node_users(bn)[0][0] == relu1 + assert stree.get_node_users(bn)[1][0] == relu2 + # check relu1 topological order + assert len(stree.get_node_inputs(relu1)) == 1 + assert stree.get_node_inputs(relu1)[0] == bn + assert len(stree.get_node_users(relu1)) == 0 + # check relu2 topological order + assert len(stree.get_node_inputs(relu2)) == 1 + assert stree.get_node_inputs(relu2)[0] == bn + # check bn and relu2 edge + assert len(relu1.get_targets()) == 1 + assert len(relu2.get_normalized_args().values()) == 1 + assert bn_output == list(relu2.get_normalized_args().values())[0] + # check ast + node_ast = relu2.get_ast() + assert isinstance(node_ast, ast.Assign) + args = node_ast.value.args + assert isinstance(args, list) + assert len(args) == 1 + assert isinstance(args[0], ast.Name) + assert args[0].id == bn_output.value + + +def test_set_node_arg_by_node(): + """ + Feature: Python api set_node_arg_by_node of SymbolTree of Rewrite. + Description: Call set_node_arg_by_node to change topological-order of a node. + Expectation: Success. + """ + stree, bn, relu1, relu2 = create_symbol_tree() + assert len(stree.nodes()) == 7 + assert len(bn.get_targets()) == 1 + bn_output = bn.get_targets()[0] + # check bn topological order + assert len(stree.get_node_users(bn)) == 1 + assert stree.get_node_users(bn)[0][0] == relu1 + # check relu1 topological order + assert len(stree.get_node_inputs(relu1)) == 1 + assert stree.get_node_inputs(relu1)[0] == bn + assert len(stree.get_node_users(relu1)) == 1 + assert stree.get_node_users(relu1)[0][0] == relu2 + # check relu2 topological order + assert len(stree.get_node_inputs(relu2)) == 1 + assert stree.get_node_inputs(relu2)[0] == relu1 + # check relu1 and relu2 edge + assert len(relu1.get_targets()) == 1 + assert len(relu2.get_normalized_args().values()) == 1 + assert relu1.get_targets()[0] == list(relu2.get_normalized_args().values())[0] + + stree.set_node_arg_by_node(relu2, 0, bn) + # check bn topological order + assert len(stree.get_node_users(bn)) == 2 + assert stree.get_node_users(bn)[0][0] == relu1 + assert stree.get_node_users(bn)[1][0] == relu2 + # check relu1 topological order + assert len(stree.get_node_inputs(relu1)) == 1 + assert stree.get_node_inputs(relu1)[0] == bn + assert len(stree.get_node_users(relu1)) == 0 + # check relu2 topological order + assert len(stree.get_node_inputs(relu2)) == 1 + assert stree.get_node_inputs(relu2)[0] == bn + # check bn and relu2 edge + assert len(relu1.get_targets()) == 1 + assert len(relu2.get_normalized_args().values()) == 1 + assert bn_output == list(relu2.get_normalized_args().values())[0] + # check ast + node_ast = relu2.get_ast() + assert isinstance(node_ast, ast.Assign) + args = node_ast.value.args + assert isinstance(args, list) + assert len(args) == 1 + assert isinstance(args[0], ast.Name) + assert args[0].id == bn_output.value + + +def test_erase_succeed(): + """ + Feature: Python api erase_node of SymbolTree of Rewrite. + Description: Call erase_node to erase a node from SymbolTree. + Expectation: Success. + """ + stree, bn, relu1, relu2 = create_symbol_tree() + construct_ast: ast.FunctionDef = getattr(stree, "_root_ast") + providers = getattr(getattr(stree, "_topo_mgr"), "_target_provider") + providers_len = len(providers) + assert len(stree.nodes()) == 7 + assert len(construct_ast.body) == 6 + + stree.set_node_arg_by_node(relu2, 0, bn) + stree.erase_node(relu1) + + assert len(stree.nodes()) == 6 + assert len(providers) == providers_len - 1 + assert len(construct_ast.body) == 5 + + +def test_erase_failed(): + """ + Feature: Python api erase_node of SymbolTree of Rewrite. + Description: Call erase_node to erase a node from SymbolTree which is not isolated. + Expectation: Failure. + """ + stree, _, relu1, _ = create_symbol_tree() + catched_error = False + try: + stree.erase_node(relu1) + except RuntimeError: + catched_error = True + assert catched_error + + +def test_replace_one_to_one(): + """ + Feature: Python api replace of SymbolTree of Rewrite. + Description: Call replace to replace an origin node to a new node. + Expectation: Success. + """ + stree, bn, relu1, relu2 = create_symbol_tree() + construct_ast: ast.FunctionDef = getattr(stree, "_root_ast") + assert len(construct_ast.body) == 6 + assert len(stree.nodes()) == 7 + + new_conv = Conv2d(16, 16, 5) + new_conv_node = NodeApi.create_call_cell(new_conv, [ScopedValue.create_naming_value("new_conv")], + bn.get_targets()).get_handler() + new_conv_node = stree.replace(relu1, [new_conv_node]) + assert len(stree.nodes()) == 7 + # check ast + assert len(construct_ast.body) == 6 + node_ast: ast.Assign = construct_ast.body[2] + func_ast: ast.Attribute = node_ast.value.func + assert func_ast.attr == new_conv_node.get_name() + # check bn topological order + assert len(stree.get_node_users(bn)) == 1 + assert stree.get_node_users(bn)[0][0] == new_conv_node + # check new_conv_node topological order + assert len(stree.get_node_inputs(new_conv_node)) == 1 + assert stree.get_node_inputs(new_conv_node)[0] == bn + assert len(stree.get_node_users(new_conv_node)) == 1 + assert stree.get_node_users(new_conv_node)[0][0] == relu2 + # check relu2 topological order + assert len(stree.get_node_inputs(relu2)) == 1 + assert stree.get_node_inputs(relu2)[0] == new_conv_node + # check arg edge + assert len(bn.get_targets()) == 1 + assert len(new_conv_node.get_normalized_args().values()) == 1 + assert bn.get_targets()[0] == list(new_conv_node.get_normalized_args().values())[0] + assert len(new_conv_node.get_targets()) == 1 + assert len(relu2.get_normalized_args().values()) == 1 + assert new_conv_node.get_targets()[0] == list(relu2.get_normalized_args().values())[0] + + +def test_replace_one_to_multi(): + """ + Feature: Python api replace of SymbolTree of Rewrite. + Description: Call replace to replace an origin node to a new node-tree. + Expectation: Success. + """ + stree, bn, relu1, relu2 = create_symbol_tree() + construct_ast: ast.FunctionDef = getattr(stree, "_root_ast") + assert len(construct_ast.body) == 6 + assert len(stree.nodes()) == 7 + + new_conv_node = NodeApi.create_call_cell(Conv2d(16, 16, 5), [ScopedValue.create_naming_value("new_conv")], + bn.get_targets()).get_handler() + new_relu_node = NodeApi.create_call_cell(ReLU(), [ScopedValue.create_naming_value("new_relu")], + new_conv_node.get_targets()).get_handler() + new_relu_node = stree.replace(relu1, [new_relu_node, new_conv_node]) + new_conv_node = new_relu_node.get_inputs()[0] + + assert len(stree.nodes()) == 8 + # check ast + assert len(construct_ast.body) == 7 + new_conv_ast: ast.Assign = construct_ast.body[2] + new_conv_func_ast: ast.Attribute = new_conv_ast.value.func + assert new_conv_func_ast.attr == new_conv_node.get_name() + new_relu_ast: ast.Assign = construct_ast.body[3] + new_relu_func_ast: ast.Attribute = new_relu_ast.value.func + assert new_relu_func_ast.attr == new_relu_node.get_name() + # check bn topological order + assert len(stree.get_node_users(bn)) == 1 + assert stree.get_node_users(bn)[0][0] == new_conv_node + # check new_conv_node topological order + assert len(stree.get_node_inputs(new_conv_node)) == 1 + assert stree.get_node_inputs(new_conv_node)[0] == bn + assert len(stree.get_node_users(new_conv_node)) == 1 + assert stree.get_node_users(new_conv_node)[0][0] == new_relu_node + # check new_relu_node topological order + assert len(stree.get_node_inputs(new_relu_node)) == 1 + assert stree.get_node_inputs(new_relu_node)[0] == new_conv_node + assert len(stree.get_node_users(new_relu_node)) == 1 + assert stree.get_node_users(new_relu_node)[0][0] == relu2 + # check relu2 topological order + assert len(stree.get_node_inputs(relu2)) == 1 + assert stree.get_node_inputs(relu2)[0] == new_relu_node + # check arg edge + assert len(bn.get_targets()) == 1 + assert len(new_conv_node.get_normalized_args().values()) == 1 + assert bn.get_targets()[0] == list(new_conv_node.get_normalized_args().values())[0] + + assert len(new_conv_node.get_targets()) == 1 + assert len(new_relu_node.get_normalized_args().values()) == 1 + assert new_conv_node.get_targets()[0] == list(new_relu_node.get_normalized_args().values())[0] + + assert len(new_relu_node.get_targets()) == 1 + assert len(relu2.get_normalized_args().values()) == 1 + assert new_relu_node.get_targets()[0] == list(relu2.get_normalized_args().values())[0] diff --git a/tests/ut/python/runtest.sh b/tests/ut/python/runtest.sh index 5cf3f4481a2..7958a0fa823 100755 --- a/tests/ut/python/runtest.sh +++ b/tests/ut/python/runtest.sh @@ -119,6 +119,12 @@ else if [ ${RET} -ne 0 ]; then exit ${RET} fi + + pytest -s $CURRPATH/rewrite/*.py + RET=$? + if [ ${RET} -ne 0 ]; then + exit ${RET} + fi fi RET=$?