add mindpore rewrite

This commit is contained in:
hangangqiang 2022-03-08 14:12:51 +08:00
parent 65c881a246
commit de196f3a25
36 changed files with 6456 additions and 0 deletions

View File

@ -38,6 +38,11 @@
"mindspore/mindspore/python/mindspore/train/serialization.py" "protected-access"
"mindspore/mindspore/python/mindspore/train/model.py" "protected-access"
"mindspore/mindspore/python/mindspore/log.py" "protected-access"
"mindspore/mindspore/python/mindspore/rewrite/api/node.py" "protected-access"
"mindspore/mindspore/python/mindspore/rewrite/node.py" "protected-access"
"mindspore/mindspore/python/mindspore/rewrite/parser_register.py" "protected-access"
"mindspore/mindspore/python/mindspore/rewrite/api/pattern_engine.py" "protected-access"
"mindspore/mindspore/python/mindspore/rewrite/symbol_tree.py" "inconsistent-return-statements"
"mindspore/model_zoo/official/cv" "missing-docstring"
"mindspore/model_zoo/official/cv" "c-extension-no-member"
"mindspore/model_zoo/official/nlp/bert_thor/src/bert_model.py" "redefined-outer-name"
@ -101,6 +106,10 @@
"mindspore/tests/ut/python/pynative_mode" "unused-variable"
"mindspore/tests/ut/python/pynative_mode/test_stop_gradient.py" "redefined-outer-name"
"mindspore/tests/ut/python/pynative_mode/test_stop_gradient.py" "super-init-not-called"
"mindspore/tests/ut/python/rewrite/test_flatten_recursive_stmt.py" "consider-using-ternary"
"mindspore/tests/ut/python/rewrite/test_node.py" "syntax-error"
"mindspore/tests/ut/python/rewrite/test_node.py" "protected-access"
"mindspore/tests/ut/python/rewrite/test_symbol_tree.py" "len-as-condition"
"mindspore/tests/ut/python/test_log.py" "possibly-unused-variable"
"mindspore/tests/ut/python/test_log.py" "protected-access"
"mindspore/tests/ut/python/train/summary/test_summary_collector.py" "protected-access"

View File

@ -306,6 +306,7 @@ install(
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/communication
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/profiler
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/compression
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/rewrite
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/run_check
DESTINATION ${INSTALL_PY_DIR}
COMPONENT mindspore

View File

@ -26,6 +26,7 @@ from .context import GRAPH_MODE, PYNATIVE_MODE, set_context, get_context, set_au
get_auto_parallel_context, reset_auto_parallel_context, ParallelMode, set_ps_context, \
get_ps_context, reset_ps_context, set_fl_context, get_fl_context
from .version import __version__
from .rewrite import *
__all__ = ["run_check"]
@ -34,3 +35,4 @@ __all__.extend(common.__all__)
__all__.extend(train.__all__)
__all__.extend(log.__all__)
__all__.extend(context.__all__)
__all__.extend(rewrite.__all__)

View File

@ -0,0 +1,32 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
MindSpore Rewrite module.
This is an experimental python package that is subject to change or deletion.
"""
from .parsers.module_parser import g_module_parser
from .parsers.class_def_parser import g_classdef_parser
from .parsers.function_def_parser import g_functiondef_parser
from .parsers.arguments_parser import g_arguments_parser
from .parsers.assign_parser import g_assign_parser
from .parsers.return_parser import g_return_parser
from .api.scoped_value import ScopedValue, ValueType
from .api.symbol_tree import SymbolTree
from .api.node import Node
from .api.node_type import NodeType
from .api.pattern_engine import PatternEngine, PatternNode, VarNode, Replacement
__all__ = ["SymbolTree", "Node", "NodeType", "ScopedValue", "ValueType", "PatternEngine", "PatternNode", "VarNode",
"Replacement"]

View File

@ -0,0 +1,17 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
MindSpore Rewrite api.
"""

View File

@ -0,0 +1,302 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Rewrite module api: Node."""
from typing import Union, Optional
from mindspore.nn import Cell
from ..node import Node as NodeImpl
from ..symbol_tree import SymbolTree as SymbolTreeImpl
from .node_type import NodeType
from .scoped_value import ScopedValue
class Node:
"""
Node is a data structure represents a source code line in network.
For the most part, Node represents an operator invoking in forward which could be an instance of `Cell`, an instance
of `Primitive` or a callable method.
`NodeImpl` mentioned below is implementation of `Node` which is not an interface of Rewrite. Rewrite recommend
invoking specific create method of `Node` to instantiate an instance of Node such as `create_call_cell` rather than
invoking constructor of `Node` directly, so don't care about what is `NodeImpl` and use its instance just as a
handler.
Args:
node (NodeImpl): A handler of `NodeImpl`.
"""
def __init__(self, node: NodeImpl):
self._node = node
def get_handler(self) -> NodeImpl:
"""
Get handler of node implementation.
Returns:
An instance of `NodeImpl`.
"""
return self._node
@staticmethod
def create_call_cell(cell: Cell, targets: [Union[ScopedValue, str]], args: [ScopedValue] = None,
kwargs: {str: ScopedValue}=None, name: str = "") -> 'Node':
"""
Create a node. Only support create from a `Cell` now.
A node is corresponding to source code like:
.. code-block::
`targets` = self.`name`(*`args`, **`kwargs`)
Args:
cell (Cell): Cell-operator of this forward-layer.
targets (list[ScopedValue]): Indicate output names. Used as targets of an assign statement in source code.
Rewrite will check and ensure the uniqueness of each target while node being inserted.
args (list[ScopedValue]): Indicate input names. Used as args of a call expression of an assign statement in
source code. Default is None indicate the `cell` has no args inputs. Rewrite will check and ensure the
uniqueness of each arg while node being inserted.
kwargs (dict{str: ScopedValue}): Indicate keyword input names. Used as kwargs of a call expression of an
assign statement in source code. Default is None indicate the `cell` has no kwargs inputs. Rewrite will
check and ensure the uniqueness of each kwarg while node being inserted.
name (str): Indicate the name of node. Used as field name in source code. Default is None. Rewrite will
generate name from `targets` when name is None. Rewrite will check and ensure the uniqueness of `name`
while node being inserted.
Returns:
An instance of `Node`.
Raises:
RuntimeError: If `cell` is not a `Cell`.
RuntimeError: If `targets` is None.
RuntimeError: If target in `targets` is not a `NamingValue`-`ScopedValue`.
RuntimeError: If arg in `args` is not a `NamingValue`-`ScopedValue` or a `CustomObjValue`-`ScopedValue`.
RuntimeError: If value of kwarg in `kwargs` is not a `NamingValue`-`ScopedValue` or a
`CustomObjValue`-`ScopedValue`.
"""
return Node(NodeImpl.create_call_cell(cell, None, targets, ScopedValue.create_naming_value(name, "self"), args,
kwargs, name))
def get_prev(self) -> 'Node':
"""
Get previous node of current node in source code order.
Returns:
An instance of `Node` as previous node.
"""
return Node(self._node.get_prev())
def get_next(self) -> 'Node':
"""
Get next node of current node in source code order.
Returns:
An instance of `Node` as next node.
"""
return Node(self._node.get_next())
def get_inputs(self) -> ['Node']:
"""
Get input nodes of current node in topological order.
Returns:
A list of instances of `Node` as input nodes.
"""
return [Node(node_impl) for node_impl in self._node.get_inputs()]
def get_users(self) -> ['Node']:
"""
Get output nodes of current node in topological order.
Returns:
A list of nodes represents users.
"""
belong_symbol_tree: SymbolTreeImpl = self._node.get_belong_symbol_tree()
if belong_symbol_tree is None:
return []
unique_results = []
for node_user in belong_symbol_tree.get_node_users(self._node):
node = node_user[0]
if node not in unique_results:
unique_results.append(node)
return [Node(node_impl) for node_impl in unique_results]
def set_arg(self, index: int, arg: Union[ScopedValue, str]):
"""
Set argument of current node.
Args:
index (int): Indicate which input being modified.
arg (Union[ScopedValue, str]): New argument to been set.
Raises:
RuntimeError: If `index` is out of range.
RuntimeError: If `arg` a `NamingValue`-`ScopedValue` or a `CustomObjValue`-`ScopedValue` when `arg` is an
`ScopedValue`.
"""
belong_symbol_tree: SymbolTreeImpl = self._node.get_belong_symbol_tree()
if belong_symbol_tree is None:
self._node.set_arg(arg, index)
else:
belong_symbol_tree.set_node_arg(self._node, index, arg)
def set_arg_by_node(self, arg_idx: int, src_node: 'Node', out_idx: Optional[int] = None):
"""
Set argument of current node by another `Node`.
Args:
arg_idx (int): Indicate which input being modified.
src_node (Node): A `Node` as new input. Can be a node or name of node.
out_idx (int, optional): Indicate which output of `src_node` as new input of current node. Default is None
which means use first output of `src_node` as new input.
Raises:
RuntimeError: If `src_node` is not belong to current `SymbolTree`.
RuntimeError: If current node and `src_node` is not belong to same `SymbolTree`.
RuntimeError: If `arg_idx` is out of range.
RuntimeError: If `out_idx` is out of range.
RuntimeError: If `src_node` has multi-outputs while `out_idx` is None or `out_idx` is not offered.
"""
belong_symbol_tree: SymbolTreeImpl = self._node.get_belong_symbol_tree()
if belong_symbol_tree is None:
self._node.set_arg_by_node(arg_idx, src_node._node, out_idx)
else:
belong_symbol_tree.set_node_arg_by_node(self._node, arg_idx, src_node._node, out_idx)
def get_targets(self) -> [ScopedValue]:
"""
Get targets of current node.
- When node_type of current node is `CallCell`, `CallPrimitive`, `CallMethod` or `Tree`, `targets` are strings
represents invoke result of the cell-op or primitive-op or function-call which are corresponding to targets of
ast.Assign.
- When node_type of current node is Input, `targets` should have only one element which is a string represents
parameter of function.
- When node_type of current node is `Python` or `Output`, `targets` are don't-care.
Returns:
A list of instances of ScopedValue as targets of node.
"""
return self._node.get_targets()
def get_name(self) -> str:
"""
Get the name of current node.
When node has been inserted into `SymbolTree`, the name of node should be unique in `SymbolTree`.
Returns:
A string as name of node.
"""
return self._node.get_name()
def get_node_type(self) -> NodeType:
"""
Get the node_type of current node.
Returns:
A NodeType as node_type of node.
"""
return self._node.get_node_type()
def get_instance_type(self) -> type:
"""
Get the instance_type of current node.
- When node_type of current node is `CallCell`, instance_type is type of cell-op.
- When node_type of current node is `CallPrimitive`, instance_type is type of primitive-op.
- When node_type of current node is `Tree`, instance_type is type of network-cell.
- When node_type of current node is `Python`, `Input`, `Output` or `CallMethod`, instance_type should be
NoneType.
Returns:
A type object represents corresponding instance type of current node.
"""
return self._node.get_instance_type()
def get_instance(self):
"""
Get the instance of current node.
- When node_type of current node is `CallCell`, instance is an instance of Cell.
- When node_type of current node is `CallPrimitive`, instance is an instance of primitive.
- When node_type of current node is `Tree`, instance is an instance of network-cell.
- When node_type of current node is `Python`, `Input`, `Output` or `CallMethod`, instance should be None.
Returns:
A object represents corresponding instance of current node.
"""
return self._node.get_instance()
def get_args(self) -> [ScopedValue]:
"""
Get the arguments of current node.
- When node_type of current node is `CallCell`, `CallPrimitive` or `Tree`, arguments are corresponding to args
of ast.Call which represents arguments to invoke forward method of cell-op or primitive-op.
- When node_type of current node is `Input`, arguments represents default-value of argument of function.
- When node_type of current node is `Output`, arguments represents return values.
- When node_type of current node is `Python`, arguments are don't-care.
Returns:
A list of instances of `ScopedValue`.
"""
return self._node.get_args()
def get_kwargs(self) -> {str: ScopedValue}:
"""
Get the keyword arguments of current node.
- When node_type of current node is `CallCell`, `CallPrimitive` or `Tree`, keyword arguments are corresponding
to kwargs of ast.Call which represents arguments to invoke forward method of cell-op or primitive-op.
- When node_type of current node is `Python`, `Input` or `Output`, keyword arguments are don't-care.
Returns:
A dict of str to instance of `ScopedValue`.
"""
return self._node.get_kwargs()
def set_attribute(self, key: str, value):
"""
Set attribute of current node.
Args:
key (str): Key of attribute.
value (object): Value of attribute.
"""
self._node.set_attribute(key, value)
def get_attributes(self) -> {str: object}:
"""
Get all attributes of current node.
Returns:
A dict of str to instance of object as attributes.
"""
return self._node.get_attributes()
def get_attribute(self, key: str):
"""
Get attribute of current node by key.
Returns:
A object as attribute.
"""
return self._node.get_attribute(key)
def __eq__(self, other: 'Node'):
return self._node == other._node

View File

@ -0,0 +1,43 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Rewrite module api: NodeType."""
from enum import Enum
class NodeType(Enum):
"""
`NodeType` represents type of `Node`.
- Unknown: Not inited NodeType.
- CallCell: `CallCell` node represents invoking cell-op in forward method.
- CallPrimitive: `CallPrimitive` node represents invoking primitive-op in forward method.
- CallMethod: `CallMethod` node represents invoking of method in forward method which can not be mapped to
cell-op or primitive-op in MindSpore.
- Python: `Python` node holds unsupported-ast-node or unnecessary-to-parse-ast-node.
- Input: `Input` node represents input of `SymbolTree` corresponding to arguments of forward method.
- Output: `Output` node represents output of SymbolTree corresponding to return statement of forward method.
- Tree: `Tree` node represents sub-network invoking in forward method.
"""
Unknown = 0
# Compute node type
CallCell = 1
CallPrimitive = 2
CallMethod = 3
# Other node type
Python = 4
Input = 5
Output = 6
Tree = 7

View File

@ -0,0 +1,413 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""PatternEngine for modifying SymbolTree by pattern."""
from collections import OrderedDict
from typing import Tuple, Union, List, Type
import abc
from mindspore.nn import Cell
from mindspore import log as logger
from .node_type import NodeType
from .node import Node
from .symbol_tree import SymbolTree
class PatternNode:
"""
`PatternNode` is defined as a node while defining pattern.
Args:
pattern_node_name (str): Name of current node.
match_type (Type): A type represents what type would be matched of current node.
inputs (list[PatternNode]): Input nodes of current node.
"""
def __init__(self, pattern_node_name: str, match_type: Type = Type[None], inputs: ['PatternNode'] = None):
self._name = pattern_node_name
self._type = match_type
if inputs is None:
self._inputs = []
else:
self._inputs = inputs
@staticmethod
def from_node(node: Node) -> 'PatternNode':
"""
Create a `PatternNode` from `node`.
Args:
node (Node): Input rewrite node.
Returns:
A `PatternNode` created from `node`.
"""
pattern_node: PatternNode = PatternNode(node.get_targets()[0])
if node.get_node_type() is NodeType.CallCell:
pattern_node._type = node.get_instance_type()
return pattern_node
@staticmethod
def create_pattern_from_node(node: Node) -> 'PatternNode':
"""
Create a Pattern from `node` with its inputs.
Args:
node (Node): Input rewrite node.
Returns:
A `PatternNode` as root of pattern created from rewrite node.
"""
pattern_node: PatternNode = PatternNode.from_node(node)
inputs = []
for node_input in node.get_inputs():
inputs.append(PatternNode.create_pattern_from_node(node_input))
pattern_node._inputs = inputs
return pattern_node
@staticmethod
def create_pattern_from_list(type_list: []) -> 'PatternNode':
"""
Create a Pattern from a cell type list.
Args:
type_list (list[type]): Input cell type list.
Returns:
A `PatternNode` as root of pattern created from cell type list.
"""
last_node = None
for i, cell_type in enumerate(type_list):
cur_node: PatternNode = PatternNode(str(i) + "-" + str(cell_type), cell_type, [])
if last_node is not None:
cur_node._inputs = [last_node]
else:
cur_node._inputs = []
last_node = cur_node
return last_node
def add_input(self, node):
"""
Add an input for current `PatternNode`.
Args:
node (PatternNode): Cell type as an input.
"""
self._inputs.append(node)
def set_inputs(self, inputs):
"""
Set inputs for current `PatternNode`.
Args:
inputs (list[PatternNode]) : Inputs to be set as inputs of current `PatternNode`.
"""
self._inputs = inputs
def match(self, node: Node) -> bool:
"""
Check if current `PatternNode` can match with `node`.
Args:
node (Node) : A rewrite node to be match.
"""
return self._type == node.get_instance_type()
def get_inputs(self):
"""
Getter of inputs.
"""
return self._inputs
def name(self) -> str:
"""
Getter of name.
"""
return self._name
def type(self):
"""
Getter of type.
"""
return self._type
class VarNode(PatternNode):
"""
VarNode is a subclass of `PatternNode` whose `match` method is always return True.
"""
def __init__(self):
super(VarNode, self).__init__("placeholder", Cell, [])
def match(self, node: Node) -> bool:
return node is not None and node.get_handler() is not None
class Replacement(abc.ABC):
"""
Interface of replacement function.
"""
@abc.abstractmethod
def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
"""
Interface define for creating replacement nodes from matched result.
Note:
Return value will be delivered into replace api of `SymbolTree` as argument, return value should follow
restraint of parameter `new_nodes` of `replace` api if `SymbolTree`. See detail in docstring of `replace`
api of `SymbolTree`.
Args:
pattern (PatternNode): A `PatternNode` represents root node of current pattern.
is_chain_pattern (bool): A bool indicated if pattern is a chain pattern or a tree pattern.
matched (OrderedDict): An OrderedDict map from pattern_node name to node represents matched result.
Returns:
A list of instance of `Node` as replacement nodes.
"""
raise NotImplementedError
def __call__(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
return self.build(pattern, is_chain_pattern, matched)
class PatternEngine:
"""
`PatternEngine` is defined how to transform a `SymbolTree` by `PattenNode`.
Args:
pattern (Union[PatternNode, List]): An instance of `PatternNode` or a cell-type-list to construct `PatternNode`
as root of a pattern.
replacement (callable): A callable define how to generate new_node.
"""
def __init__(self, pattern: Union[PatternNode, List], replacement: Replacement = None):
if isinstance(pattern, PatternNode):
self._is_chain = False
self._replacement: Replacement = replacement
self._pattern: PatternNode = pattern
elif isinstance(pattern, list):
self._is_chain = True
self._replacement: Replacement = replacement
self._pattern: PatternNode = PatternNode.create_pattern_from_list(pattern)
else:
raise RuntimeError("Unsupported pattern define")
def pattern(self) -> PatternNode:
"""
Getter of pattern.
"""
return self._pattern
@staticmethod
def _multi_to_multi_replace(stree: SymbolTree, old_root: Node, matched_dict: OrderedDict,
new_nodes: [Node]) -> Node:
"""
Replace multi-nodes in `stree` by another list of nodes.
Note:
Call replace api of `SymbolTree`, so parameter `new_nodes` has same restraint with parameter `new_nodes` of
`replace` api if `SymbolTree`. See detail in docstring of `replace` api of `SymbolTree`.
Args:
stree (SymbolTree): A `SymbolTree` which replacement will apply on.
old_root (Node): A `Node` represents root of original nodes.
matched_dict (OrderedDict): An instance of OrderedDict as match result, where key is the pattern name, value
is the matched node.
new_nodes (list[Node]): A list of instance of Node as replacement.
"""
to_erase_list = matched_dict.values()
# keep all old nodes' inputs
inputs_dict = {}
for node in to_erase_list:
inputs_dict[node.get_name()] = (node.get_inputs())
# call replace of SymbolTree
new_root = stree.replace(old_root, new_nodes)
# replace only support one-to-one replace or one-to-multi replace, we need to erase nodes except
# cur_node manually
queue: [Node] = [old_root]
while queue:
cur_node: Node = queue.pop(0)
if cur_node in to_erase_list:
if cur_node.get_users():
# if cur_node is depended on by other node, skip now.
# cur_node will be push into queue and be erased later
continue
if stree.get_node(cur_node.get_name()) is not None:
# cur_node is not erased before
stree.erase_node(cur_node)
queue.extend(inputs_dict.get(cur_node.get_name()))
return new_root
def apply(self, stree: SymbolTree) -> bool:
"""
Apply current pattern to a `SymbolTree`.
Note:
Sub-tree node will be supported in the near feature.
Args:
stree (SymbolTree): A `SymbolTree` to be transformed.
Returns:
A bool represents if `stree` been changed.
Raises:
RuntimeError: If `SymbolTree` has no return node.
"""
root: Node = stree.get_return_node()
if root is None:
raise RuntimeError("SymbolTree should be inited and has return node")
changed = False
# IR match
queue: [Node] = [root]
# Why need visited: we don't need or should not to visit same node multi-times because pre-visited node may
# already been erased from SymbolTree.
# When will we visit same node multi-times:
# a
# / \
# / \
# b c
# | |
# | d
# \ /
# \ /
# e
# 1. Visit e, e does not match pattern, add b, d to queue.
# 2. Visit b, b does not match pattern, add a to queue.
# 3. Visit d, d does not match pattern, add c to queue.
# 4. Visit a, a matches pattern and erased from SymbolTree, add xx to queue.
# 5. Visit c, d does not match pattern, add a to queue.
# At step 5, a is visited at second time but a is erased from SymbolTree at step 4.
visited: [Node] = []
while queue:
cur_node: Node = queue.pop(0)
if cur_node is None: # Because inputs of node is allowed to be None in replacement.
continue
if cur_node in visited:
continue
visited.append(cur_node)
matched, matched_dict = self._match(self._pattern, cur_node)
# not matched
if not matched or not PatternEngine._check_match(self._pattern, matched_dict):
if cur_node is not None:
queue.extend(cur_node.get_inputs())
continue
# matched
new_nodes: [Node] = []
if self._replacement is not None:
new_nodes: [Node] = self._replacement(self._pattern, self._is_chain, matched_dict)
if not new_nodes: # if replacement is empty, do nothing
queue.extend(cur_node.get_inputs())
else: # replace cur_node with new_nodes
changed = True
root = PatternEngine._multi_to_multi_replace(stree, cur_node, matched_dict, new_nodes)
queue.append(root)
return changed
@staticmethod
def _merge_ordered_dict(dict1: OrderedDict, dict2: OrderedDict) -> OrderedDict:
"""
A static util method to merge two OrderedDict.
Args:
dict1 (OrderedDict): First dict to be merged.
dict2 (OrderedDict): Second dict to be merged.
Returns:
Merged OrderedDict.
"""
merged = dict1.copy()
merged.update(dict2)
return merged
def _match(self, pattern: PatternNode, node: Node) -> Tuple[bool, OrderedDict]:
"""
Match `pattern` with a `node` with all inputs of the `pattern`.
Args:
pattern (PatternNode): Pattern to be match.
node (Node): Node to be match.
Returns:
A bool value to indicate if matched.
An instance of OrderedDict as match result, where key is the pattern name, value is the matched node.
"""
# Don't iterate into subgraph node, pattern should not be matched across sub-tree
if node.get_node_type() != NodeType.CallCell and node.get_node_type() != NodeType.Input:
logger.debug("Pattern match failed: node(%s) is not a cell", str(node))
return False, OrderedDict()
if not pattern.match(node):
logger.debug("Pattern match failed: node(%s)'s type is %s while pattern type is %s", str(node),
node.get_instance_type(), pattern.type())
return False, OrderedDict()
if isinstance(pattern, VarNode):
return True, OrderedDict()
pattern_inputs = pattern.get_inputs()
cur_inputs = node.get_inputs()
input_num = len(pattern_inputs)
if input_num == 0:
return True, OrderedDict({pattern.name(): node})
if input_num != len(cur_inputs):
logger.debug("Pattern match failed: node(%s)'s has %d inputs while pattern has %d inputs", str(node),
len(node.get_inputs()), input_num)
return False, OrderedDict()
result = OrderedDict()
for i in range(0, input_num):
is_matched, tmp_result = self._match(pattern_inputs[i], cur_inputs[i])
if not is_matched:
return False, OrderedDict()
result = PatternEngine._merge_ordered_dict(result, tmp_result)
result[pattern.name()] = node
return True, result
@staticmethod
def _check_match(pattern: PatternNode, match_dict: OrderedDict) -> bool:
"""
Check if matched result is a leak result.
A leak result means that the result is matched the `pattern`, but some nodes in result which is not
corresponding to root of pattern have outputs used by nodes outside of result.
Args:
pattern (PatternNode): A `PatternNode` represents pattern to be match.
match_dict (OrderedDict): A OrderedDict represents matched result.
Returns:
A bool value to indicate if matched result leaked.
"""
matched_nodes = match_dict.values()
for key in match_dict:
if key == pattern.name():
continue
node: Node = match_dict[key]
for output in node.get_users():
if output not in matched_nodes:
logger.debug("Check match failed, pattern leaked")
return False
return True

View File

@ -0,0 +1,152 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Rewrite module api: ValueType and ScopedValue."""
from enum import Enum
from typing import Optional
class ValueType(Enum):
"""
ValueType represents type of `ScopedValue`.
- A `NamingValue` represents a reference to another variable.
- A `CustomObjValue` represents an instance of custom class or an object whose type is out of range of base-type
and container-type of ValueType.
"""
# base type
StringValue = 0
IntValue = 1
FloatValue = 2
# container type
TupleValue = 20
ListValue = 21
DictValue = 22
# other type
NamingValue = 40
CustomObjValue = 41
class ScopedValue:
"""
`ScopedValue` represents a value with its full-scope.
`ScopedValue` is used to express: a left-value such as target of an assign statement, or a callable object such as
func of a call statement, or a right-value such as args and kwargs of an assign statement.
Args:
arg_type (ValueType): A `ValueType` represents type of current value.
scope (str): A string represents scope of current value. Take "self.var1" as an example, `scope` of this
var1 is "self".
value: A handler represents value of current value. The type of value is corresponding to `arg_type`.
"""
def __init__(self, arg_type: ValueType, scope: str = "", value=None):
self.type = arg_type
self.scope = scope
self.value = value
@classmethod
def create_variable_value(cls, value) -> Optional['ScopedValue']:
"""
Create `ScopedValue` from a variable.
`ScopedValue`'s type is determined by type of value. `ScopedValue`'s scope is empty.
Args:
value: The value to be converted to `ScopedValue`.
Returns:
An instance of `ScopedValue`.
"""
if isinstance(value, int):
return cls(ValueType.IntValue, "", value)
if isinstance(value, float):
return cls(ValueType.FloatValue, "", value)
if isinstance(value, str):
return cls(ValueType.StringValue, "", value)
if isinstance(value, tuple):
return cls(ValueType.TupleValue, "",
tuple(cls.create_variable_value(single_value) for single_value in value))
if isinstance(value, list):
return cls(ValueType.ListValue, "", list(cls.create_variable_value(single_value) for single_value in value))
if isinstance(value, dict):
for key, _ in value.items():
if not isinstance(key, str):
raise TypeError("key should be str, got: ", type(key))
return cls(ValueType.DictValue, "",
dict((key, cls.create_variable_value(single_value)) for key, single_value in value.items()))
return cls(ValueType.CustomObjValue, "", value)
@classmethod
def create_naming_value(cls, name: str, scope: str = "") -> 'ScopedValue':
"""
Create a naming `ScopedValue`. A `NamingValue` represents a reference to another variable.
Args:
name: (str): A string represents the identifier of another variable.
scope: (str): A string represents the scope of another variable.
Returns:
An instance of `ScopedValue`.
"""
return cls(ValueType.NamingValue, scope, name)
@staticmethod
def create_name_values(names: [str], scopes: [str] = None) -> ['ScopedValue']:
"""
Create a list of naming `ScopedValue`.
Args:
names: (list[str]): A list of string represents names of referenced variables.
scopes: (list[str]): A list of string represents scopes of referenced variables.
Returns:
An list of instance of `ScopedValue`.
Raise:
RuntimeError: If the length of names is not equal to the length of scopes when scopes are not None.
"""
if scopes is not None:
if len(names) != len(scopes):
raise RuntimeError("Length of names should be equal to length of scopes")
result = []
for index, name in enumerate(names):
if scopes is not None:
scope = scopes[index]
else:
scope = ""
result.append(ScopedValue.create_naming_value(name, scope))
return result
def __str__(self):
if self.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
return str(self.value)
if self.type == ValueType.NamingValue:
return f"{self.scope}.{self.value}" if self.scope else str(self.value)
if self.type == ValueType.CustomObjValue:
return f"CustomObj: {str(self.value)}"
return f"Illegal ValueType: {str(self.type)}"
def __eq__(self, other):
if id(self) == id(other):
return True
return self.type == other.type and self.scope == other.scope and self.value == other.value
def __repr__(self):
return str(self)
def __hash__(self):
return hash((self.type, self.scope, self.value))

View File

@ -0,0 +1,227 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Rewrite module api: SymbolTree."""
from typing import Optional
from mindspore.nn import Cell
from .node import Node
from ..symbol_tree_builder import SymbolTreeBuilder
from ..symbol_tree import SymbolTree as SymbolTreeImpl
class SymbolTree:
"""
A `SymbolTree` usually corresponding to forward method of a network.
Args:
network (Cell): Network to be rewritten. Only support `Cell`-type-network now.
Raises:
RuntimeError: If `network` is not a Cell.
RuntimeError: If there is any unsupported ast node type while parsing or optimizing.
"""
def __init__(self, network: Cell):
self._symbol_tree: SymbolTreeImpl = SymbolTreeBuilder(network).build()
def get_handler(self) -> SymbolTreeImpl:
"""
Get handler of `SymbolTree` implementation.
Returns:
An instance of `SymbolTree`.
"""
return self._symbol_tree
def nodes(self) -> {}:
"""
Get all nodes of corresponding network.
Returns:
A dict mapping from name of node to node.
"""
return [Node(node_impl) for node_impl in self._symbol_tree.nodes(unfold_subtree=True)]
def get_node(self, node_name: str) -> Optional[Node]:
"""
Get node by `node_name`.
Args:
node_name (str): A string represents name of node.
Returns:
An instance of node if find else None.
"""
node_impl = self._symbol_tree.get_node(node_name)
if node_impl is None:
return None
return Node(node_impl)
def get_return_node(self) -> Node:
"""
Get return node of current `SymbolTree`.
A `SymbolTree` can and should have one return node corresponding to return statement of forward method of
network.
Returns:
An instance of node represents return node.
"""
ret = self._symbol_tree.get_return_node()
if ret is None:
raise RuntimeError("SymbolTree is not well inited, can not find return node.")
return Node(ret)
def before(self, node: Node):
"""
Get insert position before input `node`.
`Position` is used to indicate where to insert node, it indicates position in source code rather than position
in topological order. We don't need to care about what `Position` is, just treat it as a handler and use it as
an arguments of `insert` api of `SymbolTree`.
Args:
node (Node): Indicate the position before which node. Can be a node or name of node.
Returns:
A `Position` to indicate where to insert node.
Raises:
RuntimeError: if `node` is not a Node or a string.
"""
return self._symbol_tree.before(node.get_handler())
def after(self, node: Node):
"""
Get insert position after input `node`.
`Position` is used to indicate where to insert node, it indicates position in source code rather than position
in topological order. We don't need to care about what `Position` is, just treat it as a handler and use it as
an arguments of `insert` api of `SymbolTree`.
Args:
node (Node): Indicate the position after which node. Can be a node or name of node.
Returns:
A `Position` to indicate where to insert node.
Raises:
RuntimeError: If `node` is not a Node.
"""
return self._symbol_tree.after(node.get_handler())
def insert(self, position, node: Node) -> Node:
"""
Insert a `node` into `SymbolTree` at `position`.
`position` is obtained from `before` api or `after` api of `SymbolTree`.
Args:
position (Position): Indicate where to insert `node`.
node (Node): An instance of Node to be inserted.
Returns:
An instance of Node being inserted. `node` could be changed while calling this method for uniqueness and
custom-object in args or kwargs.
Raises:
RuntimeError: If `position` is not belong to current `SymbolTree`.
"""
return Node(self._symbol_tree.insert_node(position, node.get_handler()))
def erase_node(self, node: Node) -> Optional[Node]:
"""
Erase a `node` from rewrite. Can only erase a node not being depended on.
Args:
node (Node): A `Node` to be erased. Can be a node or name of node.
Returns:
An instance of `Node` being erased if node is in `SymbolTree` else None.
Raises:
RuntimeError: If `node` is not a `Node`.
"""
return Node(self._symbol_tree.erase_node(node.get_handler()))
def replace(self, old_node: Node, new_nodes: [Node]) -> Node:
"""
Replace `old_node` with a node_tree.
Note:
1. Replace support one-to-one replacement or one-to-multi replacement. If you need multi-to-multi
replacement, please refer to `PatternEngine`.
2. When applying one-to-multi replacement, Rewrite will insert all `new_nodes` into symbol_tree.
3. Caller should maintain arguments and targets of nodes intra sub-tree for specifying topological relation
intra sub-tree.
4. Caller should maintain arguments of input nodes of sub-tree and for specifying topological relation of
inputs of sub-tree.
5. Rewrite will maintain arguments of prepend node of sub-tree for specifying topological relation of
outputs of sub-tree.
6. Rewrite will maintain all inputs of nodes after replace `new_nodes` into `SymbolTree`.
Args:
old_node (Node): Node to be replaced.
new_nodes (list[Node]): Nodes of the node_tree to replace in.
Returns:
An instance of Node represents root of node_tree been replaced in.
Raises:
RuntimeError: Old node is isolated.
"""
nodes_impl = [node.get_handler() for node in new_nodes]
return Node(self._symbol_tree.replace(old_node.get_handler(), nodes_impl))
def set_output(self, index: int, return_value: str) -> Node:
"""
Set return value of network.
Args:
index (int): Indicate which output being modified.
return_value (str): New return value to been set.
Returns:
Return node of current rewrite.
Raises:
RuntimeError: If `index` is out of range.
"""
return Node(self._symbol_tree.set_output(return_value, index))
def dump(self):
"""
Dump graph to console.
"""
self._symbol_tree.dump()
def get_code(self) -> str:
"""
Get source code of modified network.
Returns:
A str represents source code of modified network.
"""
return self._symbol_tree.get_code()
def get_network(self) -> Cell:
"""
Get modified network.
Returns:
A network object.
"""
return self._symbol_tree.get_network()

View File

@ -0,0 +1,292 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Ast utils for create or update ast node."""
from typing import Optional
import ast
from .api.scoped_value import ScopedValue, ValueType
class AstModifier(ast.NodeTransformer):
"""Ast utils for create or update ast node."""
@staticmethod
def erase_ast_from_function(ast_func: ast.FunctionDef, to_erase: ast.AST) -> bool:
"""
Erase ast node from ast.FunctionDef.
Args:
ast_func (ast.FunctionDef): From which to search to_erase-node and erase.
to_erase (ast.AST): Node to be erased.
Returns:
A bool if to_erase-node been found and been erased.
"""
for body in ast_func.body:
if id(body) == id(to_erase):
ast_func.body.remove(body)
return True
return False
@staticmethod
def insert_assign_to_function(ast_func: ast.FunctionDef, targets: [ScopedValue], expr: ScopedValue,
args: [ScopedValue] = None, kwargs: {str, ScopedValue}=None,
index_ast: Optional[ast.AST] = None, insert_before=True) -> ast.AST:
"""
Insert an ast.Assign into an ast.FunctionDef.
Args:
ast_func (ast.FunctionDef): Where new ast.Assign to be inserted into.
targets ([ScopedValue]): Targets of ast.Assign.
expr (ScopedValue): Func of ast.Call which is value of new ast.Assign.
args ([ScopedValue]): Args of ast.Call which is value of new ast.Assign.
kwargs ({str, ScopedValue}): Kwargs of ast.Call which is value of new ast.Assign.
index_ast (Optional[ast.AST]): An ast_node indicates a position in 'ast_func' where new ast.Assign node to
be inserted into. Default is None which means append new ast.Assign into
'ast_func'.
insert_before (bool): A bool indicates at before or at after of 'index_ast' where new ast.Assign node to be
inserted into. Only valid when 'index_ast' is not None. Default is True which means
inserting new ast.Assign before 'index_ast'.
Returns:
An instance of ast.Assign which has been inserted into 'ast_func'.
Raises:
RuntimeError: If 'index_ast' is not contained in 'ast_func'.
"""
assign = AstModifier.create_call_assign(targets, expr, args, kwargs)
return AstModifier.insert_assign_ast_to_function(ast_func, assign, index_ast, insert_before)
@staticmethod
def insert_assign_ast_to_function(ast_func: ast.FunctionDef, ast_assign: ast.Assign,
index_ast: Optional[ast.AST] = None, insert_before=True) -> ast.AST:
"""
Insert an ast.Assign into an ast.FunctionDef.
Args:
ast_func (ast.FunctionDef): Where new ast.Assign to be inserted into.
ast_assign (ast.Assign): An instance of ast.Assign to be inserted in.
index_ast (Optional[ast.AST]): An ast_node indicates a position in 'ast_func' where new ast.Assign node to
be inserted into. Default is None which means append new ast.Assign to
'ast_func'.
insert_before (bool): A bool indicates at before or at after of 'index_ast' where new ast.Assign node to be
inserted into. Only valid when 'index_ast' is not None. Default is True which means
inserting new ast.Assign before 'index_ast'.
Returns:
An instance of ast.Assign which has been inserted into 'ast_func'.
Raises:
RuntimeError: If 'index_ast' is not contained in 'ast_func'.
"""
if index_ast is None:
ast_func.body.append(ast_assign)
ast.fix_missing_locations(ast_func)
return ast_assign
for index in range(0, len(ast_func.body)):
if id(ast_func.body[index]) == id(index_ast):
if insert_before:
ast_func.body.insert(index, ast_assign)
else:
ast_func.body.insert(index + 1, ast_assign)
ast.fix_missing_locations(ast_func)
return ast_assign
raise RuntimeError("index_ast is not contained in ast_func")
@staticmethod
def append_global_vars_expr_to_init(init_func: ast.FunctionDef, targets: [ScopedValue],
field: str) -> ast.AST:
"""
Append an ast.Assign to an ast.FunctionDef which is function named "__init__" in network. Value of new
ast.Assign is an ast.Call represents get an object from global_vars dict.
While user inserting a custom op, the instance of new custom op is saved in a dict named global_vars. Rewrite
need to get the custom op instance from global_vars in new "__init__" function of network:
self.var1 = global_vars.get("var1")
Args:
init_func (ast.FunctionDef): An instance of ast.FunctionDef which is "__init__" function of network.
targets ([ScopedValue]): Targets of ast.Assign.
field (str): A string represents name of new custom op field.
Returns:
An instance of ast.Assign which has been appended to 'init_func'.
"""
return AstModifier.insert_assign_to_function(init_func, targets=targets,
args=[ScopedValue.create_variable_value(field)],
expr=ScopedValue(ValueType.NamingValue, "global_vars", "get"))
@staticmethod
def create_call_assign(targets: [ScopedValue], expr: ScopedValue, args: [ScopedValue],
kwargs: {str, ScopedValue}) -> ast.Assign:
"""
Create an instance of ast.Assign whose value must ba a ast.Call.
Args:
targets ([ScopedValue]): Targets of ast.Assign.
expr (ScopedValue): Func of ast.Call which is value of new ast.Assign.
args ([ScopedValue]): Args of ast.Call which is value of new ast.Assign.
kwargs ({str, ScopedValue}): Kwargs of ast.Call which is value of new ast.Assign.
Returns:
An instance of ast.Assign.
Raises:
RuntimeError: If 'targets' is None.
RuntimeError: If value_type of element of 'targets' is not ValueType.NamingValue.
RuntimeError: If length of 'targets' is not 1. Multi-targets will be support in the future.
"""
if targets is None or len(targets) != 1:
raise RuntimeError("Only support one target in insert_cell_to_init now")
if targets[0].type != ValueType.NamingValue:
raise RuntimeError("Target must be a right-value, got: ", targets[0])
if targets[0].scope:
ast_target = ast.Attribute(ast.Name(targets[0].scope, ast.Load()), targets[0].value, ast.Store())
else:
ast_target = ast.Name(targets[0].value, ast.Store())
call = AstModifier.create_call(expr, args, kwargs)
result = ast.Assign(targets=[ast_target], value=call)
ast.fix_missing_locations(result)
return result
@staticmethod
def create_call(expr: ScopedValue, args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None) -> ast.Call:
"""
Create an instance of ast.Call.
Args:
expr (ScopedValue): Func of ast.Call.
args ([ScopedValue]): Args of ast.Call.
kwargs ({str, ScopedValue}): Kwargs of ast.Call.
Returns:
An instance of ast.Call.
Raises:
RuntimeError: If value_type of 'expr' is ValueType.CustomObjValue.
RuntimeError: If value_type of 'expr' is not ValueType.NamingValue.
RuntimeError: If value_type of element of 'args' is ValueType.CustomObjValue.
RuntimeError: If value_type of value of 'kwargs' is ValueType.CustomObjValue.
TypeError: If expr is not an instance of ScopedValue.
RuntimeError: If element of 'args' is not an instance of ScopedValue.
RuntimeError: If value of 'kwargs' is not an instance of ScopedValue.
"""
if not isinstance(expr, ScopedValue):
raise TypeError("expr should be ScopedValue, got: ", type(expr))
if expr.type == ValueType.CustomObjValue:
raise RuntimeError("Please handle custom-object first")
if expr.type != ValueType.NamingValue:
raise RuntimeError("Expr must not be a constant, because constant can not been called: ", expr.type)
if expr.scope:
ast_func = ast.Attribute(ast.Name(expr.scope, ast.Load()), expr.value, ast.Store())
else:
ast_func = ast.Name(expr.value, ast.Store())
ast_args = []
if args is not None:
for arg in args:
if not isinstance(arg, ScopedValue):
raise TypeError("arg should be ScopedValue, got: ", type(arg))
if arg.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
if arg.scope:
raise RuntimeError("arg.scope should be empty")
ast_args.append(ast.Constant(value=arg.value, kind=None))
elif arg.type == ValueType.NamingValue:
if arg.scope:
ast_args.append(ast.Attribute(ast.Name(arg.scope, ast.Load()), arg.value, ast.Store()))
else:
ast_args.append(ast.Name(arg.value, ast.Store()))
else:
raise RuntimeError("Please handle custom-object first")
keywords = []
if kwargs is not None:
for arg, value in kwargs.items():
if not isinstance(value, ScopedValue):
raise TypeError("value should be ScopedValue, got: ", type(value))
if value.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
if value.scope:
raise RuntimeError("value.scope should be empty")
keywords.append(ast.keyword(arg=arg, value=ast.Constant(value=value.value, kind=None)))
elif value.type == ValueType.NamingValue:
if value.scope:
keywords.append(ast.keyword(arg=arg,
value=ast.Attribute(ast.Name(value.scope, ast.Load()), value.value,
ast.Store())))
else:
keywords.append(ast.keyword(arg=arg, value=ast.Name(value.value, ast.Store())))
else:
raise RuntimeError("Please handle custom-object first")
result = ast.Call(func=ast_func, args=ast_args, keywords=keywords)
ast.fix_missing_locations(result)
return result
@staticmethod
def update_arg_value(src_argument: ScopedValue, dst_ast: ast.AST):
"""
Update 'arg_value' by 'input_argument'
Args:
src_argument (ScopedValue): An instance of ScopedValue represents new argument.
dst_ast (ast.AST): Targets of ast.Assign.
Raises:
TypeError: Input src_argument is not a ScopedValue
RuntimeError: If 'dst_ast' is an instance of ast.Constant but type of 'src_argument' is not
ValueType.IntValue, ValueType.FloatValue or ValueType.StringValue.
RuntimeError: If 'dst_ast' is an instance of ast.Name or ast.Attribute but type of 'src_argument' is not
ValueType.NamingValue.
RuntimeError: When 'dst_ast' is an instance of ast.Name, scope of 'src_argument' is not empty.
RuntimeError: When 'dst_ast' is an instance of ast.Attribute, value of 'dst_ast' is not an instance of
ast.Name.
RuntimeError: If 'dst_ast' is an instance of ast.Tuple but type of 'src_argument' is not
ValueType.TupleValue.
RuntimeError: If 'dst_ast' is an instance of ast.Constant, ast.Name, ast.Attribute or ast.Tuple.
RuntimeError: When 'dst_ast' is an instance of ast.Tuple, length of elts of 'dst_ast' is not equal to length
of value of 'src_argument'.
"""
if not isinstance(src_argument, ScopedValue):
raise TypeError("src_argument should be ScopedValue, got: ", type(src_argument))
if isinstance(dst_ast, ast.Constant):
if src_argument.type not in [ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue]:
raise RuntimeError("src_argument should be a IntValue, FloatValue or StringValue, got:",
str(src_argument.type))
dst_ast.value = src_argument.value
return
if isinstance(dst_ast, ast.Name):
if src_argument.type != ValueType.NamingValue:
raise RuntimeError("src_argument.type should equal to ValueType.NamingValue")
if src_argument.scope:
raise RuntimeError("src_argument.scope should be empty")
dst_ast.id = src_argument.value
return
if isinstance(dst_ast, ast.Attribute):
if src_argument.type != ValueType.NamingValue:
raise RuntimeError("src_argument.type should equal to ValueType.NamingValue")
attr_value = dst_ast.value
if not isinstance(attr_value, ast.Name):
raise RuntimeError("Only support ast.Name as value of attribute ", type(attr_value))
attr_value.id = src_argument.scope
dst_ast.attr = src_argument.value
return
if isinstance(dst_ast, ast.Tuple):
if src_argument.type != ValueType.TupleValue:
raise RuntimeError("src_argument.type should equal to ValueType.TupleValue")
if len(src_argument.value) != len(dst_ast.elts):
raise RuntimeError("src_argument.value and elts in ast should have same length")
for elt_index, elt in enumerate(dst_ast.elts):
AstModifier.update_arg_value(src_argument.value[elt_index], elt)
return
raise RuntimeError("keyword value type is not supported", type(dst_ast))

View File

@ -0,0 +1,16 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Transformers for optimizing ast."""
from .flatten_recursive_stmt import FlattenRecursiveStmt

View File

@ -0,0 +1,154 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Ast optimizer for flatten recursive call."""
from typing import Any, Tuple
import ast
from ast import FunctionDef
from mindspore import log as logger
class FlattenRecursiveStmt(ast.NodeTransformer):
"""Ast optimizer for flatten recursive call."""
def __init__(self):
"""
Constructor of FlattenRecursiveStmt.
Returns:
An instance of ast optimizer for flatten recursive call.
"""
self._flatten_table: dict = {
ast.Return: ["value"],
ast.Call: ["args"],
ast.BinOp: ["left", "right"],
ast.BoolOp: ["values"],
ast.unaryop: ["operand"],
}
@staticmethod
def _generate_target_name(node: ast.AST, target_names):
"""Generate unique target name."""
if isinstance(node, ast.Call):
func = node.func
if isinstance(func, ast.Name):
target_name = func.id
elif isinstance(func, ast.Attribute):
target_name = func.attr
else:
logger.warning("unhandled type of func of ast.Call while generating new target name: %s ", type(func))
target_name = "function"
elif isinstance(node, ast.Return):
target_name = "return_value"
elif isinstance(node, (ast.BinOp, ast.boolop, ast.UnaryOp)):
target_name = type(node.op).__name__
else:
logger.warning("unhandled type of node while generating new target name: %s ", type(node))
target_name = type(node).__name__
suffix = 0
result = target_name
while result in target_names:
suffix += 1
result = f"{target_name}_{suffix}"
target_names.append(result)
return result
@staticmethod
def _fill_in_original_target_names(target_names, node):
"""Fill in original target names before getting unique names."""
for function_index in range(len(node.body)):
child = node.body[function_index]
if not isinstance(child, ast.Assign):
continue
targets = child.targets
for target in targets:
if not isinstance(target, ast.Name):
raise RuntimeError("currently only support ast.Name targets")
target_name = target.id
if target_name not in target_names:
target_names.append(target_name)
@staticmethod
def _create_new_assign_node(node: ast.AST, target_names) -> Tuple[str, ast.AST]:
"""Create new assign node to be inserted into ast.FunctionDef."""
if isinstance(node, (ast.Name, ast.Constant, ast.Num, ast.Str, ast.NameConstant, ast.Bytes, ast.Ellipsis)):
return "", node
new_target_name = FlattenRecursiveStmt._generate_target_name(node, target_names)
return new_target_name, ast.Assign(targets=[ast.Name(id=new_target_name, ctx=ast.Store())], value=node)
def _flatten_statement(self, node: ast.AST, target_names) -> [ast.AST]:
"""Flatten recursive statement according to different node type."""
flatten_config = self._flatten_table.get(type(node))
if flatten_config is None:
return []
results = []
for todo_name in flatten_config:
todos = getattr(node, todo_name)
if isinstance(todos, list):
new_list = []
for todo in todos:
new_target_name, new_node = FlattenRecursiveStmt._create_new_assign_node(todo, target_names)
if id(new_node) == id(todo):
new_list.append(todo)
else:
new_list.append(ast.Name(id=new_target_name, ctx=ast.Load()))
results.append(new_node)
setattr(node, todo_name, new_list)
elif isinstance(todos, dict):
new_dict = []
for key, value in todos:
new_target_name, new_node = FlattenRecursiveStmt._create_new_assign_node(value, target_names)
if id(new_node) == id(value):
new_dict[key] = value
else:
new_dict[key] = ast.Name(id=new_target_name, ctx=ast.Load())
results.append(new_node)
setattr(node, todo_name, new_dict)
else:
new_target_name, new_node = FlattenRecursiveStmt._create_new_assign_node(todos, target_names)
if id(new_node) != id(todos):
setattr(node, todo_name, ast.Name(id=new_target_name, ctx=ast.Load()))
results.append(new_node)
return results
def visit_FunctionDef(self, node: FunctionDef) -> Any:
"""Traverse construct node and flatten recursive nodes."""
if node.name != "construct":
return node
target_names = []
self._fill_in_original_target_names(target_names, node)
index = len(node.body) - 1
while index >= 0:
child = node.body[index]
if isinstance(child, ast.Assign):
stmt = child.value
elif isinstance(child, ast.Expr):
stmt = child.value
else:
stmt = child
results = self._flatten_statement(stmt, target_names)
if results:
results.reverse()
for result in results:
node.body.insert(index, result)
index += 1
index -= 1
return node
def transform(self, ast_root):
"""Interface of FlattenRecursiveStmt."""
ast_root = self.visit(ast_root)
ast_root = ast.fix_missing_locations(ast_root)
return ast_root

View File

@ -0,0 +1,159 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Unique name producer for target, name of node."""
from typing import Union
from .node import Node
from .api.node_type import NodeType
class Namer:
"""
Used for unique identity in a class-scope. current used for target of construct-function.
Namer records times of name been used, and add prefix to origin name for unique name. For example, when a Namer
record "name1" has been used 10 times, when a new request require a unique name base on 'name1', namer will respond
"name1_10" as unique name.
"""
def __init__(self):
"""Constructor of Namer."""
self._names: {str: int} = {}
@staticmethod
def _real_name(name: str) -> str:
"""
Find real name. For example, "name1" is the real name of "name1_10", "name1" is the real name of "name1_10_3".
If not find real name before find unique name, unique name may be not unique. For example:
1. "name1" has been used 10 times, which means "name1", "name1_2", "name1_3" ... "name1_10" has been used;
2. new request require a unique name base on 'name1_5'
3. If namer not find real name of "name1_5", namer will find that "name1_5" is never used and respond
"name1_5" as unique name which is used before, actually.
Args:
name (str): Origin name which may have digit prefix.
Returns:
A string represents real-name.
"""
pos = name.rfind("_")
if pos == -1:
return name
digit = True
for i in range(pos + 1, len(name)):
if not name[i].isdigit():
digit = False
break
if digit:
return Namer._real_name(name[:pos])
return name
def get_name(self, origin_name: str) -> str:
"""
Get unique name from 'origin_name'.
Args:
origin_name (str): Origin name which may be duplicated.
Returns:
A string represents unique-name.
"""
origin_name = Namer._real_name(origin_name)
number = self._names.get(origin_name)
if number is None:
self._names[origin_name] = 1
return origin_name
self._names[origin_name] = number + 1
return f"{origin_name}_{number}"
def add_name(self, name: str):
"""
Add a name to Namer which should be unique.
Args:
name (str): A name should be unique in current namer.
Raises:
RuntimeError: If name is not unique in current namer.
"""
real_name = Namer._real_name(name)
number = self._names.get(real_name)
if number is not None:
raise RuntimeError("name duplicated: ", name)
self._names[name] = 1
class TargetNamer(Namer):
"""
Used for unique-ing targets of node.
"""
def get_real_arg(self, origin_arg: str) -> str:
"""
Get real argument from 'origin_arg' because target of node which produces 'origin_arg' may be change for unique.
Args:
origin_arg (str): Origin argument string which may be undefined cause to target unique-lize.
Returns:
A string represents real argument name.
"""
num = self._names.get(origin_arg)
if num is None or num == 1:
return origin_arg
return f"{origin_arg}_{num - 1}"
class NodeNamer(Namer):
"""
Used for unique-ing node-name which is also used as field of init-function and key of global_vars
"""
def get_name(self, node_or_name: Union[Node, str]) -> str:
"""
Override get_name in Namer class.
Get unique node_name from 'origin_name' or an instance of node.
Args:
node_or_name (Union[Node, str]): A string represents candidate node_name or an instance of node who require
A unique node_name.
Returns:
A string represents unique node_name.
"""
if isinstance(node_or_name, Node):
origin_name = node_or_name.get_name()
if origin_name is None or not origin_name:
if node_or_name.get_node_type() == NodeType.CallCell:
if not isinstance(node_or_name, Node):
raise TypeError("node_or_name should be Node, got: ", type(node_or_name))
targets = node_or_name.get_targets()
# return node and head node will not call this method
if not targets:
raise RuntimeError("node should has at lease one target except return-node and head-node: ",
node_or_name)
origin_name = str(targets[0].value)
elif node_or_name.get_node_type() == NodeType.Python:
origin_name = node_or_name.get_instance().__name__
elif node_or_name.get_node_type() == NodeType.Input:
origin_name = "parameter"
else:
raise RuntimeError("Node type unsupported:", node_or_name.get_node_type())
elif isinstance(node_or_name, str):
if not node_or_name:
raise RuntimeError("input node_name is empty.")
origin_name = node_or_name
else:
raise RuntimeError("unexpected type of node_or_name: ", type(node_or_name))
return super(NodeNamer, self).get_name(origin_name)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,45 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Base class of parser."""
import abc
import ast
from .symbol_tree import SymbolTree
class Parser(abc.ABC):
"""
DFS into a ast_node until add node into SymbolTree
"""
def target(self) -> type:
"""
Get type of ast which could be accepted by current parser.
Returns:
A type of ast.
"""
return type(None)
@abc.abstractmethod
def process(self, stree: SymbolTree, node: ast.AST):
"""
Parse input ast node and add parse result into SymbolTree.
Args:
stree (SymbolTree): current symbol_tree
node (ast.AST): node who is tried to be parsed
"""
raise NotImplementedError

View File

@ -0,0 +1,86 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parser register."""
from typing import Optional
from .parser import Parser
class ParserRegister:
"""Parser register."""
def __init__(self):
self._parsers: dict = {}
@classmethod
def instance(cls) -> 'ParserRegister':
"""
Get singleton of ParserRegister.
Returns:
An instance of ParserRegister.
"""
if not hasattr(ParserRegister, "_instance"):
ParserRegister._instance = ParserRegister()
return ParserRegister._instance
@staticmethod
def reg_parser(parser: Parser):
"""
Register a 'parser' to current ParserRegister.
Args:
parser (Parser): An instance of Parser to be registered.
"""
if isinstance(parser, Parser):
ParserRegister.instance()._parsers[parser.target()] = parser
def get_parser(self, ast_type: type) -> Optional[Parser]:
"""
Get parser from current ParserRegister by type of ast.
Args:
ast_type (type): An type of ast which want to be parsed.
Returns:
An instance of Parser if there existing suitable parser in current ParserRegister else None.
"""
return self._parsers.get(ast_type)
def get_parsers(self) -> [Parser]:
"""
Get all parsers registered in current ParserRegister.
Returns:
An list of instances of Parser.
"""
return self._parsers
class ParserRegistry:
"""Parser registry."""
def __init__(self, parser: Parser):
ParserRegister.instance().reg_parser(parser)
def reg_parser(parser: Parser):
"""
A global method for registering parser into ParserRegister singleton.
Args:
parser (Parser): An instance of Parser to be registered.
"""
return ParserRegistry(parser)

View File

@ -0,0 +1,17 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Parsers for resolve ast to SymbolTree
"""

View File

@ -0,0 +1,61 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse ast.arguments to input-node of SymbolTree."""
import ast
from ..parser import Parser
from ..parser_register import reg_parser
from ..symbol_tree import SymbolTree
class ArgumentsParser(Parser):
"""Parse ast.arguments to input-node of SymbolTree."""
def target(self):
"""Parse target type"""
return ast.arguments
def process(self, stree: SymbolTree, node: ast.arguments):
"""
Parse ast.arguments and create input-node to stree.
Args:
stree (SymbolTree): symbol tree under parsing.
node (ast.arguments): argument node in construct.
Raises:
RuntimeError: Types of node.args elements are not ast.arg.
"""
if hasattr(node, "posonlyargs"):
stree.try_append_python_node(node, node.posonlyargs)
for arg in node.args:
if not isinstance(arg, ast.arg):
raise RuntimeError("Unsupported ast type in arguments arg: ", arg)
stree.append_input_node(arg.arg)
if hasattr(node, "vararg"):
stree.try_append_python_node(node, node.vararg)
if hasattr(node, "kwonlyargs"):
stree.try_append_python_node(node, node.kwonlyargs)
if hasattr(node, "kw_defaults"):
stree.try_append_python_node(node, node.kw_defaults)
if hasattr(node, "kwarg"):
stree.try_append_python_node(node, node.kwarg)
if hasattr(node, "defaults"):
stree.try_append_python_node(node, node.defaults)
g_arguments_parser = reg_parser(ArgumentsParser())

View File

@ -0,0 +1,252 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse ast.Assign in construct function to node of SymbolTree."""
import ast
import astunparse
from mindspore import log as logger
from ..symbol_tree import SymbolTree
from ..node import Node, TreeNode
from ..parser import Parser
from ..parser_register import reg_parser
from ..api.scoped_value import ScopedValue
from ..symbol_tree_builder import SymbolTreeBuilder
class AssignParser(Parser):
"""Parse ast.Assign in construct function to node of SymbolTree."""
def target(self):
"""Parse target type."""
return ast.Assign
@staticmethod
def _create_scopedvalue_from_tuple_ast(node: ast.Tuple) -> ScopedValue:
"""
Create ScopedValue from a tuple ast node.
Args:
node (ast.Tuple): A tuple node.
Returns:
An instance of ScopedValue.
Raises:
RuntimeError: Only support ast.Constant as elts of ast.Tuple.
"""
tuple_elts = node.elts
tuple_values = []
for tuple_elt in tuple_elts:
if not isinstance(tuple_elt, ast.Constant):
raise RuntimeError("Only support ast.Constant as elts of ast.Tuple.")
tuple_values.append(tuple_elt.value)
return ScopedValue.create_variable_value(tuple(tuple_values))
@staticmethod
def _create_scopedvalue(node: ast.expr) -> ScopedValue:
"""
Create ScopedValue from an ast node.
Args:
node (ast.expr): An ast node.
Returns:
An instance of ScopedValue.
Raises:
RuntimeError: Value of target of ast.Assign should be an ast.Name when target is an ast.Attribute.
RuntimeError: Type of input node is unsupported.
"""
if isinstance(node, ast.Name):
return ScopedValue.create_naming_value(node.id)
if isinstance(node, ast.Attribute):
scope = node.value
if not isinstance(scope, ast.Name):
raise RuntimeError("value of target of ast.Assign should be a ast.Name when target is a ast.Attribute.")
return ScopedValue.create_naming_value(node.attr, scope.id)
if isinstance(node, ast.Tuple):
return AssignParser._create_scopedvalue_from_tuple_ast(node)
if isinstance(node, ast.Constant):
return ScopedValue.create_variable_value(node.value)
raise RuntimeError("Unsupported ast type to argument:", node)
@staticmethod
def _get_func_name(ast_node: ast.Call) -> str:
"""
Get the func name from ast.Call.
Args:
ast_node (ast.Call): Input ast.Call node.
Returns:
Func name.
Raises:
RuntimeError: Func of input ast node is not ast.Name or ast.Attribute.
"""
func = ast_node.func
if isinstance(func, ast.Name):
return func.id
if isinstance(func, ast.Attribute):
return func.attr
raise RuntimeError("FuncValue is should be Name or a Attribute:", astunparse.unparse(func))
@staticmethod
def _get_func_scope(ast_node: ast.Call) -> str:
"""
Get the func scope from ast.Call.
Args:
ast_node (ast.Call): Input ast.Call node.
Returns:
Func scope.
Raises:
RuntimeError: FuncValue is not an ast.Name when func is an ast.Attribute.
RuntimeError: Func of input ast node is not ast.Name or ast.Attribute.
"""
func = ast_node.func
if isinstance(func, ast.Name):
return ""
if isinstance(func, ast.Attribute):
value = func.value
if not isinstance(value, ast.Name):
raise RuntimeError("FuncValue is should be Name:", ast.dump(func))
return value.id
raise RuntimeError("FuncValue is should be Name or a Attribute:", ast.dump(func))
@staticmethod
def _get_symbol_object(symbol_name, origin_net):
"""
Get the func scope from ast.Call.
Args:
symbol_name (str): Func name.
origin_net ([nn.Cell]): Network instance.
Returns:
Symbol Object.
"""
var_dict = origin_net.__dict__
for key, value in var_dict["_cells"].items():
if key == symbol_name:
return value
for key, value in var_dict["_primitives"].items():
if key == symbol_name:
return value
return None
@staticmethod
def _create_kwargs(keywords: [ast.keyword]) -> {str, ScopedValue}:
"""
Transfer ast.Call keywords to a dict of ScopedValue when creating a symbol tree node.
Args:
keywords ([ast.keyword]): Keywords of ast.Call node.
Returns:
A dict of ScopedValue.
"""
results = {}
for keyword in keywords:
results[keyword.arg] = AssignParser._create_scopedvalue(keyword.value)
return results
@staticmethod
def _convert_ast_call_to_node(ast_node: ast.Call, father_ast_node: ast.Assign, stree: SymbolTree) -> Node:
"""
Convert ast.Call to a symbol tree node.
Args:
ast_node ([ast.Call]): An ast.Call of assign node in construct.
father_ast_node ([ast.Assign]): Assign node in construct.
stree ([SymbolTree]): Symbol Tree under parsing.
Returns:
An instance of Node in Symbol Tree.
Raises:
RuntimeError: kwargs in construct function assign is unsupported.
"""
target = AssignParser._create_scopedvalue(father_ast_node.targets[0])
func_name = AssignParser._get_func_name(ast_node)
if func_name is None or func_name == "":
raise RuntimeError("function name not exist")
func_scope = AssignParser._get_func_scope(ast_node)
func = ScopedValue.create_naming_value(func_name, func_scope)
call_args = [AssignParser._create_scopedvalue(arg) for arg in ast_node.args]
call_kwargs = AssignParser._create_kwargs(ast_node.keywords)
if ast_node.keywords:
raise RuntimeError("kwargs in construct function assign is unsupported.")
obj = AssignParser._get_symbol_object(func_name, stree.get_origin_network())
# need check if node is a callmethod, like: x = len(x)
# need check if node is a callprimitive, like: x = x * 5
is_sub_tree = False
if is_sub_tree:
stb = SymbolTreeBuilder(obj)
new_stree = stb.build()
return TreeNode(new_stree, father_ast_node, [target], func, call_args, call_kwargs, func_name,
new_stree.get_origin_network())
return Node.create_call_cell(obj, father_ast_node, [target], func, call_args, call_kwargs, func_name)
def process(self, stree: SymbolTree, node: ast.Assign):
"""
Parse ast.Assign and create a node in symbol tree.
Will create node when value of ast.Assign is in [ast.Call, ast.Name, ast.Constant, ast.Attribute].
Will create python node when value of ast.Assign is in
[ast.BinOp, ast.BoolOp, ast.Subscript, ast.List, ast.Tuple, ast.Dict].
Other value types are not supported.
Args:
stree ([SymbolTree]): Symbol Tree under parsing.
node ([ast.Assign]): An ast.Assign node.
Raises:
RuntimeError: Only support one target in assign now.
RuntimeError: Unsupported node type in construct function.
"""
targets = node.targets
if len(targets) != 1:
raise RuntimeError("Only support one target in assign now")
value = node.value
if isinstance(value, ast.Call):
node_ = AssignParser._convert_ast_call_to_node(value, node, stree)
stree.append_origin_field(node_)
elif isinstance(value, (ast.BinOp, ast.BoolOp, ast.Subscript)):
logger.warning(f"ops-call({astunparse.unparse(node)}) in assign will be supported in near feature, "
f"ignored as a python node now")
stree.try_append_python_node(node, node)
elif isinstance(value, (ast.Name, ast.Constant, ast.Attribute, ast.Num, ast.NameConstant, ast.Bytes, ast.Str)):
if isinstance(value, ast.Name):
node_name = "name_assign"
elif isinstance(value, ast.Constant):
node_name = "constant_assign"
else:
node_name = "attribute_assign"
target = AssignParser._create_scopedvalue(node.targets[0])
call_args = [AssignParser._create_scopedvalue(value)]
node_ = Node.create_call_pass_through_method(node, [target], call_args, {}, node_name)
stree.append_origin_field(node_)
elif isinstance(value, (ast.List, ast.Tuple, ast.Dict)):
# add these as callmethod node if necessary
stree.try_append_python_node(node, node)
else:
raise RuntimeError(f"Unsupported statement({astunparse.unparse(node)}) in construct function!")
g_assign_parser = reg_parser(AssignParser())

View File

@ -0,0 +1,170 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse ast.ClassDef which is subclass of Cell to SymbolTree."""
import ast
from mindspore import log as logger
from ..symbol_tree import SymbolTree
from ..parser import Parser
from ..parser_register import ParserRegister, reg_parser
from ..api.scoped_value import ScopedValue
from ..ast_modifier import AstModifier
class ClassDefParser(Parser):
"""Parse ast.ClassDef which is subclass of Cell to SymbolTree."""
def target(self):
"""Parse target type"""
return ast.ClassDef
@staticmethod
def _process_init_func_ast(init_ast: ast.FunctionDef, ori_cls_name: str, opt_cls_name: str):
"""Process init func"""
super_index = ClassDefParser._modify_super_expr_of_init_func(init_ast, ori_cls_name, opt_cls_name)
ClassDefParser._modify_arguments_of_init_func(init_ast)
ClassDefParser._replace_ori_field_of_init_func(init_ast.body, super_index)
ClassDefParser._insert_handler_to_init_func(init_ast, super_index)
@staticmethod
def _modify_super_expr_of_init_func(ast_init_fn: ast.FunctionDef, ori_cls_name: str, opt_cls_name: str) -> int:
"""Modify network name in super(XXnet).__init__()"""
if not ast_init_fn.body:
return -1
super_index = -1
super_call_args = None
while True:
super_index += 1
expr = ast_init_fn.body[super_index]
if not isinstance(expr, ast.Expr):
continue
expr_value = expr.value
if not isinstance(expr_value, ast.Call):
continue
expr_value_func = expr_value.func
if not isinstance(expr_value_func, ast.Attribute):
continue
expr_value_func_value = expr_value_func.value
if expr_value_func.attr != "__init__" or not isinstance(expr_value_func_value, ast.Call):
continue
expr_value_func_value_func = expr_value_func_value.func
if not isinstance(expr_value_func_value_func, ast.Name) or expr_value_func_value_func.id != "super":
continue
super_call_args = expr_value_func_value.args
break
if super_call_args is None or not isinstance(super_call_args, list) or len(super_call_args) != 2:
return super_index
super_call_arg = super_call_args[0]
if super_call_arg.id != ori_cls_name:
raise RuntimeError("super_call_arg.id should equal to ori_cls_name")
super_call_arg.id = opt_cls_name
return super_index
@staticmethod
def _modify_arguments_of_init_func(ast_init_fn: ast.FunctionDef):
"""Replace init function input parameters to self and global_vars."""
arg_self = ast.arg(arg="self", annotation="")
arg_global_vars = ast.arg(arg="global_vars", annotation="")
ast_init_fn.args = ast.arguments(args=[arg_self, arg_global_vars], posonlyargs=[], kwonlyargs=[],
kw_defaults=[], defaults=[], vararg=None, kwarg=None)
ast.fix_missing_locations(ast_init_fn)
@staticmethod
def _replace_ori_field_of_init_func(bodies: [], super_index: int):
"""
Replace original field in init func to self.XX = getattr(self._handler, "XX").
Only keep following two kinds of ast nodes in bodies right now:
1. Ast.If and test is self.XX.
2. Ast.Assign and target is self.XX.
Args:
bodies ([]): bodied of init ast.FunctionDef.
super_index (int): index of super().__init__() in bodies.
Raises:
RuntimeError: Not support multi-targets in assign.
RuntimeError: Only support target.value in [ast.Name] in assign node.
"""
body_index_to_be_deleted = []
for body_index, body in enumerate(bodies):
if body_index == super_index:
continue # ignoring super.__init__()
if isinstance(body, ast.If) and isinstance(body.test, ast.Attribute) \
and isinstance(body.test.value, ast.Name) and body.test.value.id == 'self':
ClassDefParser._replace_ori_field_of_init_func(body.body, -1)
ClassDefParser._replace_ori_field_of_init_func(body.orelse, -1)
continue
if not isinstance(body, ast.Assign): # if not assign node, delete
body_index_to_be_deleted.append(body_index)
continue
if len(body.targets) != 1:
raise RuntimeError("Not support multi-targets in assign now!")
target = body.targets[0]
if not isinstance(target, ast.Attribute): # only keep class member
body_index_to_be_deleted.append(body_index)
continue
if not isinstance(target.value, ast.Name):
raise RuntimeError("Only support target.value in ast.Name now!")
target_value: ast.Name = target.value
if target_value.id != "self":
body_index_to_be_deleted.append(body_index)
continue
field_name = target.attr
body.value = ast.Call(ast.Name('getattr', ast.Load()),
[ast.Attribute(ast.Name('self', ast.Load()), '_handler', ast.Load()),
ast.Constant(value=field_name, kind=None)], [])
for counter, index in enumerate(body_index_to_be_deleted):
bodies.pop(index - counter)
@staticmethod
def _insert_handler_to_init_func(ast_init_fn: ast.FunctionDef, super_index):
"""Insert 'self._handler = global_vars.get('handler')' to init ast.FunctionDef.body"""
if super_index == -1:
super_index = 0
AstModifier.insert_assign_to_function(ast_init_fn, [ScopedValue.create_naming_value("_handler", "self")],
ScopedValue.create_naming_value("get", "global_vars"),
[ScopedValue.create_variable_value("handler")], None,
ast_init_fn.body[super_index], False)
def process(self, stree: SymbolTree, node: ast.ClassDef):
"""
Parse init and construct in ast.ClassDef.
Args:
stree ([SymbolTree]): Symbol Tree under parsing.
node ([ast.ClassDef]): An ast.ClassDef node.
"""
# change class name
node.name = stree.get_opt_cls_name()
stree.set_class_ast(node)
for body in node.body:
if isinstance(body, ast.FunctionDef):
if body.name == "__init__":
ClassDefParser._process_init_func_ast(body, stree.get_ori_cls_name(), stree.get_opt_cls_name())
stree.set_init_func_ast(body)
elif body.name == "construct":
parser: Parser = ParserRegister.instance().get_parser(ast.FunctionDef)
parser.process(stree, body)
else:
logger.warning(
"Ignoring ast.FunctionDef in ast.ClassDef except __init__ and construct function: %s",
body.name)
else:
logger.warning("Ignoring unsupported node(%s) in ast.ClassDef.", type(body).__name__)
g_classdef_parser = reg_parser(ClassDefParser())

View File

@ -0,0 +1,53 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse bodies of ast.FunctionDef which is construct function to nodes of SymbolTree."""
import ast
from ..parser_register import ParserRegister, reg_parser
from ..parser import Parser
from ..symbol_tree import SymbolTree
class FunctionDefParser(Parser):
"""Parse bodies of ast.FunctionDef which is construct function to nodes of SymbolTree."""
def target(self):
"""Parse target type"""
return ast.FunctionDef
def process(self, stree: SymbolTree, node: ast.FunctionDef):
"""Parse bodies of ast.FunctionDef which is construct function to nodes of SymbolTree."""
stree.set_ast_root(node)
# parse args as inputs of stree
arguments: ast.arguments = node.args
parser: Parser = ParserRegister.instance().get_parser(ast.arguments)
parser.process(stree, arguments)
# parse body as node of stree
for body in node.body:
# avoid add dead code, so we need to break if return is added.
parser: Parser = ParserRegister.instance().get_parser(type(body))
if parser is None:
stree.append_python_node(node, body)
else:
parser.process(stree, body)
if hasattr(node, "decorator_list"):
stree.try_append_python_node(node, node.decorator_list)
if hasattr(node, "returns"):
stree.try_append_python_node(node, node.returns)
g_functiondef_parser = reg_parser(FunctionDefParser())

View File

@ -0,0 +1,124 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse ast.Module to SymbolTrees."""
from typing import Any
import os
import ast
import copy
import inspect
import astunparse
from mindspore import log as logger
from ..symbol_tree import SymbolTree
from ..parser import Parser
from ..parser_register import ParserRegister, reg_parser
class ClassFinder(ast.NodeVisitor):
"""Find all ast.ClassDef in input ast node."""
def __init__(self):
"""Keep all found ast.ClassDef in self._classes"""
self._classes: [ast.ClassDef] = []
def visit_ClassDef(self, node: ast.ClassDef) -> Any:
"""Iterate over all nodes and save ast.ClassDef nodes."""
self._classes.append(node)
def find_all_classes(self, node: ast.AST) -> [ast.ClassDef]:
"""Interface of ClassFinder."""
self.visit(node)
return self._classes
class ModuleParser(Parser):
"""Parse ast.Module to SymbolTrees."""
def target(self):
"""Parse target type"""
return ast.Module
@staticmethod
def _find_class(ast_node: ast.Module) -> ast.ClassDef:
"""Find all ast.ClassDef in ast.Module, only support one ast.ClassDef in ast.Module now."""
visitor = ClassFinder()
classes = visitor.find_all_classes(ast_node)
if not classes:
raise RuntimeError("No class in module")
if len(classes) > 1:
raise RuntimeError("Multi-class in module is not supported now")
return classes[0]
@staticmethod
def get_import_node(ast_root):
"""Iterate over ast_root and return all ast.Import nodes or ast.ImportFrom nodes in ast_root."""
import_nodes = []
class GetImportNode(ast.NodeVisitor):
"""Find all import nodes from input ast node."""
def visit_Import(self, node: ast.Import) -> Any:
"""Iterate over all nodes and save ast.Import nodes."""
import_nodes.append(copy.deepcopy(node))
return node
def visit_ImportFrom(self, node: ast.ImportFrom) -> Any:
"""Iterate over all nodes and save ast.ImportFrom nodes."""
import_nodes.append(copy.deepcopy(node))
return node
def get_node(self, input_ast):
"""Interface of GetImportNode."""
self.generic_visit(input_ast)
return True
get_node_handler = GetImportNode()
get_node_handler.get_node(ast_root)
return import_nodes
@staticmethod
def _add_import_to_module(module: ast.Module, origin_net):
"""Insert two groups of import nodes to ast.Module, common ones and those from class definition file."""
module.body.insert(0, ast.Import([ast.alias(name='mindspore', asname=None)]))
module.body.insert(1, ast.ImportFrom(module='mindspore', names=[ast.alias(name='nn', asname=None)], level=0))
module.body.insert(2, ast.ImportFrom(module='mindspore.nn', names=[ast.alias(name='Cell', asname=None)],
level=0))
origin_net_source_code_file = inspect.getfile(type(origin_net))
if not os.path.exists(origin_net_source_code_file):
raise RuntimeError("File ", origin_net_source_code_file, " not exist")
try:
with open(origin_net_source_code_file, "r") as f:
source_code = f.read()
import_nodes = ModuleParser.get_import_node(ast.parse(source_code))
except RuntimeError:
raise RuntimeError("get import nodes error")
if import_nodes:
for import_index, import_node in enumerate(import_nodes):
module.body.insert(import_index + 3, import_node)
ast.fix_missing_locations(module)
def process(self, stree: SymbolTree, node: ast.Module):
"""Process ast.ClassDef nodes in ast.Module."""
ModuleParser._add_import_to_module(node, stree.get_origin_network())
class_ast = ModuleParser._find_class(node)
stree.set_class_ast(class_ast)
for body in node.body:
if isinstance(body, ast.ClassDef):
parser: Parser = ParserRegister.instance().get_parser(ast.ClassDef)
parser.process(stree, body)
else:
logger.info(f"Ignoring unsupported node({astunparse.unparse(body)}) in ast.Module.")
g_module_parser = reg_parser(ModuleParser())

View File

@ -0,0 +1,40 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse ast.Return output-node of SymbolTree."""
import ast
from ..symbol_tree import SymbolTree
from ..node import Node
from ..parser import Parser
from ..parser_register import reg_parser
class ReturnParser(Parser):
"""Parse ast.Return output-node of SymbolTree."""
def target(self):
"""Parse target type"""
return ast.Return
def process(self, stree: SymbolTree, node: ast.Return):
"""Parse ast.Return to output-node of SymbolTree."""
return_value = node.value
if not isinstance(return_value, ast.Name):
raise RuntimeError("Only ast.Name as return value")
node_return = Node.create_output_node(node, [return_value.id])
stree.append_origin_field(node_return)
g_return_parser = reg_parser(ReturnParser())

View File

@ -0,0 +1,899 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SymbolTree class define of Rewrite according to forward function of a network."""
from typing import Optional, Union, Tuple
import os
import sys
import ast
import tempfile
import astunparse
from mindspore.nn import Cell
from mindspore import log as logger
from .node import Node, TreeNode
from .api.node_type import NodeType
from .ast_modifier import AstModifier
from .api.scoped_value import ScopedValue, ValueType
from .symbol_tree_dumper import SymbolTreeDumper
from .topological_manager import TopoManager
from .namer import TargetNamer, NodeNamer
class Position:
"""Position indicates a source code position in one network."""
def __init__(self, symbol_tree, node, before_node: bool):
"""
Constructor of Position.
Recommend to use class method of position rather than constructor of Position.
Args:
symbol_tree (SymbolTree): A handler of SymbolTree indicated position in which SymbolTree.
node (Node): A handler of Node indicated position is around which Node.
before_node (bool): A bool indicated position is before or after the 'node'.
"""
self.symbol_tree = symbol_tree
self.node = node
self.before_node = before_node
@classmethod
def create(cls, symbol_tree, node, before_node):
"""
Class method of Position. Return None when symbol_tree or node is None.
Args:
symbol_tree: A handler of SymbolTree indicated position in which SymbolTree.
node: A handler of Node indicated position is around which Node.
before_node (bool): A bool indicated position is before or after the 'node'.
Returns:
A Position.
"""
if symbol_tree is None or node is None:
return None
return Position(symbol_tree, node, before_node)
class SymbolTree:
"""A symbol-tree usually corresponding to forward method of a network."""
def __init__(self, origin_network: Cell, module_ast: ast.Module):
"""
Constructor of SymbolTree. Rewrite recommend using SymbolTreeBuilder to instantiate an instance of SymbolTree
rather than invoking constructor of SymbolTree directly.
Args:
origin_network (Cell): A handler to original network instance.
module_ast (ast.Module): An instance of ast.AST represents ast node of original network.
"""
origin_network_key = "handler"
# init unique-namers
self._target_namer = TargetNamer()
self._node_name_namer = NodeNamer()
# name or node would use as name of field, so name of origin network handler field should be added into \
# _node_name_namer.
self._node_name_namer.add_name(origin_network_key)
self._topo_mgr = TopoManager()
self._global_vars: {str, object} = {origin_network_key: origin_network}
self._nodes: {str, Node} = {}
# parameters of forward method
self._inputs: [Node] = []
self._ori_cls_name = type(origin_network).__name__
self._opt_cls_name = self._ori_cls_name + "Opt"
self._origin_network = origin_network
self._module_ast: ast.Module = module_ast
self._class_ast: Optional[ast.ClassDef] = None
self._root_ast: Optional[ast.FunctionDef] = None
self._init_func_ast: Optional[ast.FunctionDef] = None
# head node is always point to the first node(in source code order) of SymbolTree
self._head = None
# tail node is always point to the last node(in source code order) of SymbolTree
self._tail = None
self._return: Optional[Node] = None
def get_ori_cls_name(self) -> str:
"""
Get class name of original network.
Returns:
A str represents class name of original network.
"""
return self._ori_cls_name
def get_opt_cls_name(self) -> str:
"""
Get class name of rewritten network.
Returns:
A str represents class name of rewritten network.
"""
return self._opt_cls_name
def get_module_ast(self):
"""
Getter of `_module_ast`.
Returns:
An instance of ast.AST represents ast node of corresponding module.
"""
return self._module_ast
def get_ast_root(self):
"""
Getter of `_root_ast`.
Returns:
An instance of ast.AST represents ast node of corresponding forward method.
"""
return self._root_ast
def set_class_ast(self, ast_node: ast.ClassDef):
"""
Setter of `_class_ast`.
Args:
ast_node (ast.ClassDef): An instance of ast.ClassDef represents ast node of corresponding network class.
"""
self._class_ast = ast_node
def set_init_func_ast(self, ast_node: ast.FunctionDef):
"""
Setter of `_init_func_ast`.
Args:
ast_node (ast.FunctionDef): An instance of ast.FunctionDef represents ast node of init method of
corresponding network class.
"""
self._init_func_ast = ast_node
def set_ast_root(self, ast_node: ast.FunctionDef):
"""
Setter of `_root_ast`.
Args:
ast_node (ast.FunctionDef): An instance of ast.FunctionDef represents ast node of forward method of
corresponding network class.
"""
self._root_ast = ast_node
def get_inputs(self):
"""
Getter of `_inputs` which represents parameters of current forward method.
Returns:
A list of instance of Node whose node_type is NodeType.Input as input nodes.
"""
return self._inputs
def get_head_node(self):
"""
Getter of `_head` which represents the beginning node while iterating SymbolTree nodes.
Returns:
An instance of node.
"""
return self._head
def get_return_node(self):
"""
Getter of `_return` which represents return statement of forward method of network.
Returns:
An instance of node.
"""
return self._return
def get_origin_network(self):
"""
Getter of `_origin_network`.
Returns:
An instance of Cell which represents original network.
"""
return self._origin_network
def nodes(self, unfold_subtree=True):
"""
Getter of nodes if current SymbolTree.
Args:
unfold_subtree (bool): Need to iterate into sub-symbol-tree recursively.
Returns:
A list of instance of Nodes.
"""
if unfold_subtree:
nodes = []
for _, v in self._nodes.items():
if isinstance(v, TreeNode):
nodes.extend(self.nodes(v.symbol_tree))
else:
nodes.append(v)
return nodes
return self._nodes.values()
def get_node(self, node_name: str) -> Optional[Node]:
"""
Get node of current symbol_tree by `node_name`.
Args:
node_name (str): A str represents name of node as key of query.
Returns:
An instance of Node if found else None.
"""
return self._nodes.get(node_name)
def _get_real_node(self, node_or_name: Union[Node, str]) -> Optional[Node]:
if isinstance(node_or_name, Node):
return self.get_node(node_or_name.get_name())
if isinstance(node_or_name, str):
return self.get_node(node_or_name)
return None
def get_node_inputs(self, node_or_name: Union[Node, str]) -> [Node]:
"""
Getter of inputs in topological relation of current 'node_or_name'.
Args:
node_or_name (Union[Node, str]): An instance of node or a str represents name of node.
Returns:
A list of instances of Node as input nodes if 'node_or_name' belong to current SymbolTree. An empty list if
'node_or_name' not belong to current SymbolTree.
"""
real_node: Optional[Node] = self._get_real_node(node_or_name)
if real_node is None:
logger.info("Node(%s) is not belong to current SymbolTree", node_or_name)
return []
return node_or_name.get_inputs()
def get_node_users(self, node_or_name: Union[Node, str]) -> [Tuple[Node, int]]:
"""
Getter of outputs in topological relation of current 'node_or_name'.
Args:
node_or_name (Union[Node, str]): An instance of node or a str represents name of node.
Returns:
A list of instances of Node as output nodes if 'node_or_name' belong to current SymbolTree. An empty list if
'node_or_name' not belong to current SymbolTree.
"""
real_node: Optional[Node] = self._get_real_node(node_or_name)
if real_node is None:
logger.info("Node(%s) is not belong to current SymbolTree", node_or_name)
return []
return self._topo_mgr.get_node_users(node_or_name)
def before(self, node_or_name: Union[Node, str]) -> Position:
"""
Get insert position before 'node_or_name' in source code list.
Consider using symbol_tree, node and before/after as position for sub-tree feature.
Note:
Topological order is not determined here which is determined by arguments of node and updated by
TopologicalManager automatically.
Args:
node_or_name (Union[Node, str]): An instance of node or a str represents name of node.
Returns:
A Position represents an insert point.
Raises:
AssertError: If 'node_or_name' is not a Node or a str
RuntimeError: If 'node_or_name' is not belong to this SymbolTree or any sub-SymbolTree of current
SymbolTree.
"""
node = self._get_real_node(node_or_name)
if node is None:
raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
return Position.create(node.get_belong_symbol_tree(), node, True)
def after(self, node_or_name: Union[Node, str]) -> Position:
"""
Get insert position after 'node_or_name' in source code list.
Consider using symbol_tree, node and before/after as position for sub-tree feature.
Note:
Topological order is not determined here which is determined by arguments of node and updated by
TopologicalManager automatically.
Args:
node_or_name (Union[Node, str]): An instance of node or a str represents name of node.
Returns:
A Position represents an insert point.
Raises:
AssertError: If 'node_or_name' is not a Node or a str
RuntimeError: If 'node_or_name' is not belong to this SymbolTree or any sub-SymbolTree of current
SymbolTree.
"""
node = self._get_real_node(node_or_name)
if node is None:
raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
return Position.create(node.get_belong_symbol_tree(), node, False)
def insert_node(self, position: Position, node: Node, insert_to_ast: bool = True) -> Node:
"""
Insert a node into SymbolTree.
Note:
Name of node will be unique while inserting node into SymbolTree.
ValueType.CustomObjValue type arguments will be converted to ValueType.NamingValue and custom object will
be saved in global_vars dict while inserting node into SymbolTree.
Targets of node will be unique while inserting node into SymbolTree.
A field instantiation statement will be added into "init" function of network class using node name as field
name when `insert_to_ast` is True while inserting node into SymbolTree.
An assign statement represents invoking to this node will be added into forward function of network class
corresponding to field-instantiation-statement when `insert_to_ast` is True while inserting node into
SymbolTree.
Topological relation is updated and inputs of corresponding node is updated.
Args:
position (Position): A Position indicates an insert position point.
node (Node): An instance of node to be inserted in.
insert_to_ast (bool): A bool indicates whether to update corresponding ast node at same time, default is
True.
Returns:
An instance of node which has been inserted into SymbolTree.
Raises:
RuntimeError: If 'position' is not in current SymbolTree.
RuntimeError: If corresponding ast node is not an ast.Assign when 'insert_to_ast' is True.
"""
# if position in current SymbolTree
if position is not None and position.symbol_tree is not self:
raise RuntimeError("Position is not in current SymbolTree:", position)
# unique targets, name while insert node into symbol_tree
node_name = self._node_name_namer.get_name(node)
node.set_name(node_name)
self._handle_custom_obj_in_normalized_args(node)
# _unique_targets must called after _update_args_for_unique and _update_kwargs_for_unique
self._unique_targets(node)
self._insert_node(position, node)
# update init-function-ast and construct-function-ast
if insert_to_ast:
node.set_func(ScopedValue.create_naming_value(node_name, "self"))
node_ast = node.get_ast()
if not isinstance(node_ast, ast.Assign):
raise RuntimeError("Only support insert cell op now")
AstModifier.insert_assign_to_function(self._init_func_ast,
targets=[ScopedValue(ValueType.NamingValue, "self", node_name)],
expr=ScopedValue(ValueType.NamingValue, "global_vars", "get"),
args=[ScopedValue(ValueType.StringValue, "", node_name)])
AstModifier.insert_assign_ast_to_function(self._root_ast, node_ast,
None if position is None else position.node.get_ast(),
position.before_node)
self._global_vars[node_name] = node.get_instance()
return node
def append_node(self, node: Node, append_to_ast: bool = True) -> Node:
"""
Append a node to SymbolTree.
Args:
node (Node): An instance of node to be appended.
append_to_ast (bool): A bool indicates whether to update corresponding ast node at same time, default is
True.
Returns:
An instance of node which has been appended to SymbolTree.
"""
return self.insert_node(Position.create(self, self._tail, False), node, append_to_ast)
def append_origin_field(self, node: Node) -> Node:
"""
Append an original field node to SymbolTree. An original field node represents a node created from existing
statement in forward method, from existing ast node in ast of forward method, so ast node do not need to update
while these nodes appending to SymbolTree.
This method is called while building SymbolTree usually.
Args:
node (Node): An instance of node to be appended.
Returns:
An instance of node which has been appended to SymbolTree.
"""
self._update_args_kwargs_for_unique(node)
if node.get_node_type() == NodeType.Output:
self._return = node
elif node.get_node_type() == NodeType.Input:
self._inputs.append(node)
return self.append_node(node, False)
def append_input_node(self, param_name: str, default: Optional[ScopedValue] = None):
"""
Append an input node to SymbolTree corresponding to parameter of forward method of network class.
This method is called while building SymbolTree usually.
Args:
param_name (str): A str represents name of parameter of forward method of network class.
default (Optional[ScopedValue] ): A ScopedValue represents default value of parameter. Default is None which
means parameter has no default value.
Returns:
An instance of input node which has been appended to SymbolTree.
"""
if param_name == "self":
return
for input_node in self._inputs:
targets = input_node.get_targets()
if len(targets) != 1:
raise RuntimeError("targets should have 1 elements")
target: ScopedValue = targets[0]
if target.type != ValueType.NamingValue:
raise RuntimeError("target.type should equal to ValueType.NamingValue")
if target.scope != "":
raise RuntimeError("target.scope should be empty")
exist_param = target.value
if exist_param == param_name:
raise RuntimeError("input duplicated:", param_name)
input_node = Node.create_input_node(None, param_name, default, name=f"input_{param_name}")
self.append_origin_field(input_node)
def try_append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST) -> Optional[Node]:
"""
Try appending a python node to SymbolTree if 'ast_node' is not None and 'ast_node' is not Empty if 'ast_node' is
a list or a dict.
This method is called while building SymbolTree usually.
Args:
ast_scope (ast.AST): A ast node represents ast node of scope of node.
ast_node (ast.AST): A ast node represents ast node.
Returns:
An instance of python node if a new node has been appended to SymbolTree else None.
"""
if ast_node is None:
return None
if isinstance(ast_node, (list, dict)) and not ast_node:
return None
return self.append_python_node(ast_scope, ast_node)
def append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST) -> Node:
"""
Append a python node to SymbolTree.
This method is called while building SymbolTree usually.
Args:
ast_scope (ast.AST): A ast node represents ast node of scope of node.
ast_node (ast.AST): A ast node represents ast node.
Returns:
An instance of python node which has been appended to SymbolTree.
"""
logger.info("Ignoring unsupported node(%s) in %s.", type(ast_node).__name__, type(ast_scope).__name__)
node_name = self._node_name_namer.get_name(type(ast_node).__name__)
node = Node.create_python_node(ast_node, node_name)
self._insert_node(Position.create(self, self._tail, True), node)
return node
def set_output(self, return_value: str, index: int) -> Node:
"""
Update return value of return of forward method of network class.
Args:
return_value (str): A str represents new return value.
index (int): A int indicates which return value to be updated.
Returns:
An instance of node represents return node after updated.
"""
if self._return is None:
raise RuntimeError("SymbolTree has no output")
self.set_node_arg(self._return, index, return_value)
return self._return
def erase_node(self, node_or_name: Union[Node, str]) -> Node:
"""
Erase a node from SymbolTree.
Note:
If node is depended on by other node, RuntimeError will raise.
Topological relation is updated.
Args:
node_or_name (Union[Node, str]): An instance of node or a str represents name of node.
Returns:
An instance of node which has been erased from SymbolTree.
Raises:
RuntimeError: If 'node_or_name' is not in current SymbolTree.
RuntimeError: If erase corresponding ast node failed.
"""
node = self._get_real_node(node_or_name)
if node is None:
raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
ret = AstModifier.erase_ast_from_function(self._root_ast, node.get_ast())
if not ret:
raise RuntimeError("node not in function ast tree.")
for key, value in self._nodes.items():
if id(value) == id(node):
self._nodes.pop(key)
value.isolate()
break
self._topo_mgr.on_erase_node(node)
return node
def _insert_tree(self, position: Position, root: Node, insert_to_ast: bool = True) -> Node:
"""
Insert a node-tree into SymbolTree.
Note:
Inputs of intra sub-tree nodes need to be welly set.
Inputs of inter sub-tree nodes will be updated by Rewrite automatically.
Args:
position (Position): A Position indicates an insert position point.
root (Node): An instance of node as root of node-tree to be inserted in.
insert_to_ast (bool): A bool indicates whether to update corresponding ast node at same time, default is
True.
Returns:
An instance of node as root node of node-tree which has been inserted into SymbolTree.
Raises:
RuntimeError: If 'position' is not in current SymbolTree.
"""
# if position not in current SymbolTree
if position.symbol_tree is not self:
raise RuntimeError("Position is not in current SymbolTree: ", position)
queue: [Node] = [root]
todos: [] = []
inputs_list: [] = []
while queue:
cur_node = queue.pop(0)
if cur_node in todos:
continue
todos.append(cur_node)
node_inputs = cur_node.get_inputs()
inputs_list.append(node_inputs)
for node_input in node_inputs:
if node_input is not None:
queue.append(node_input)
todos.reverse()
inputs_list.reverse()
for index, todo in enumerate(todos):
self.insert_node(position, todo, insert_to_ast)
position = self.after(todo)
# relink input of node
original_inputs = inputs_list[index]
for arg_idx, original_input in enumerate(original_inputs):
if original_input is not None:
self.set_node_arg_by_node(todo, arg_idx, original_input)
return root
@staticmethod
def _link_nodes_and_find_root(nodes: [Node]) -> Node:
"""
Find inputs for all nodes created by Replacement according to their targets and arguments.
Find root node of all nodes created by Replacement. One and Only one root should be found.
Args:
nodes ([Node]): A list of instance of Node created by Replacement.
Returns:
An instance of Node represents root of input nodes.
"""
consumers: [ScopedValue] = []
target_dict: {ScopedValue: Node} = {}
for node in nodes:
consumers.extend(node.get_args())
for _, arg in node.get_kwargs():
consumers.append(arg)
for target in node.get_targets():
if target_dict.get(target) is not None:
raise RuntimeError("Target of node duplicated")
target_dict[target] = node
# find root node
root = None
for node in nodes:
used = 0
for target in node.get_targets():
if target in consumers:
used += 1
if used == 0:
if root is not None:
raise RuntimeError("Replacement should only has one root")
root = node
if root is None:
raise RuntimeError("No root node found in replacement nodes")
# link node's input
for node in nodes:
inputs = []
for _, arg in node.get_normalized_args().items():
node_input: Node = target_dict.get(arg)
if node_input is None:
inputs.append(None)
else:
inputs.append(node_input)
node.set_inputs(inputs)
return root
def replace(self, old_node: Node, new_nodes: [Node]) -> Node:
"""
Replace an old_node with a node_tree. 'new_node' is the root node of the node_tree.
Note:
Rewrite will iterate all nodes linked to this root node and insert these nodes into symbol_tree.
Inputs of intra sub-tree nodes need to be welly set.
Inputs of inter sub-tree nodes will be updated by Rewrite automatically.
Args:
old_node (Node): Node to be replaced.
new_nodes ([Node]): Node tree to replace in.
Returns:
An instance of Node represents root of node_tree been replaced in.
Raises:
RuntimeError: If 'old_node' is isolated.
RuntimeError: If 'old_node' is not belong to current SymbolTree.
"""
real_old_node = self._get_real_node(old_node)
if real_old_node is None:
raise RuntimeError("Old node is not belong to current SymbolTree:", old_node)
# get position
next_node: Node = old_node.get_next()
prev_node: Node = old_node.get_prev()
if prev_node is None and next_node is None:
raise RuntimeError("Try replacing a isolated node: ", old_node)
if next_node is None:
position = self.after(prev_node)
else:
position = self.before(next_node)
# insert node first, because targets of new_node is determined after insert
new_tree_root = SymbolTree._link_nodes_and_find_root(new_nodes)
new_node = self._insert_tree(position, new_tree_root)
# use targets of insert tree to redirect edge
users = self.get_node_users(old_node)
if len(new_node.get_targets()) != 1:
raise RuntimeError("targets of new_node should have 1 elements")
for user in users:
self.set_node_arg_by_node(user[0], user[1], new_node)
# erase old_node after edge is redirected because node can be erased only when node is isolated topologically
self.erase_node(old_node)
return new_node
def set_node_arg(self, node: Union[Node, str], index: int, arg: Union[ScopedValue, str]):
"""
Set argument of 'node'.
Args:
node (Union[Node, str]): Node to be modified. Can be a node or name of node.
index (int): Indicate which input being modified.
arg (Union[ScopedValue, str]): New argument to been set.
Raises:
RuntimeError: If 'node' is not belong to current SymbolTree.
"""
real_node = self._get_real_node(node)
if real_node is None:
raise RuntimeError("Node is not belong to current SymbolTree: ", node)
new_arg, old_arg = node.set_arg(arg, index)
self._topo_mgr.on_update_arg(node, index, old_arg, new_arg)
def set_node_arg_by_node(self, dst_node: Union[Node, str], arg_idx: int, src_node: Union[Node, str],
out_idx: Optional[int] = None):
"""
Set argument of 'dst_node' by another Node.
Args:
dst_node (Node): Node to be modified. Can be a node or name of node.
arg_idx (int): Indicate which input being modified.
src_node (Node): Node as new input. Can be a node or name of node.
out_idx (Optional[int]): Indicate which output of 'src_node' as new input of 'dst_node'. Default is None
which means use first output of 'node_to_link' as new input.
Raises:
RuntimeError: If 'dst_node' is not belong to current SymbolTree.
RuntimeError: If 'src_node' is not belong to current SymbolTree.
RuntimeError: If 'out_idx' is out of range.
RuntimeError: If 'src_node' has multi-outputs while 'out_idx' is None or 'out_idx' is not offered.
"""
real_dst_node = self._get_real_node(dst_node)
if real_dst_node is None:
raise RuntimeError("dst_node is not belong to current SymbolTree: ", dst_node)
real_src_node = self._get_real_node(src_node)
if real_src_node is None:
raise RuntimeError("src_node is not belong to current SymbolTree: ", src_node)
targets = real_src_node.get_targets()
if out_idx is None:
if len(targets) != 1:
raise RuntimeError("node should has one output when out_idx is not provided")
out_idx = 0
if out_idx >= len(targets):
raise RuntimeError("out_idx out of range: ", out_idx)
new_arg = targets[out_idx]
self.set_node_arg(real_dst_node, arg_idx, new_arg)
def dump(self):
"""Dump graph."""
dump_st = SymbolTreeDumper(self)
dump_st.dump()
def get_code(self) -> str:
"""
Get source code of modified network.
Returns:
A str represents source code of modified network.
"""
ast.fix_missing_locations(self._module_ast)
return astunparse.unparse(self._module_ast)
def get_network(self):
"""
Get modified network.
Returns:
A network object.
"""
cls = self._get_cls_through_file()
return cls(self._global_vars)
def _unique_targets(self, node: Node):
"""
Unique targets of node by _target_namer.
Args:
node (Node): A Node whose targets to be uniqued.
"""
new_targets: [ScopedValue] = []
if node.get_targets() is None:
return
for target in node.get_targets():
if not isinstance(target, ScopedValue):
raise TypeError("target should be ScopedValue, got: ", type(target))
unique_target = self._target_namer.get_name(target.value)
new_targets.append(ScopedValue.create_naming_value(unique_target, target.scope))
node.set_targets(new_targets)
def _update_args_kwargs_for_unique(self, node: Node):
"""
Update arguments and keyword arguments of node because unique-ing of targets of other nodes.
Args:
node (Node): A Node whose arguments and keyword arguments to be updated.
"""
result: {str: ScopedValue} = {}
if node.get_normalized_args() is None:
return
for key, arg in node.get_normalized_args().items():
if not isinstance(arg, ScopedValue):
raise TypeError("arg should be ScopedValue, got: ", type(arg))
if arg.type == ValueType.NamingValue:
# unique name
new_arg = ScopedValue(arg.type, arg.scope, self._target_namer.get_real_arg(arg.value))
result[key] = new_arg
else:
result[key] = arg
node.set_normalized_args(result)
def _add_node2nodes(self, node: Node):
"""
Add `node` to `_nodes` dict.
Args:
node (Node): A Node to be added into `_nodes`.
Raises:
RuntimeError: If name of the node is duplicated.
"""
node_name = node.get_name()
if self._nodes.get(node_name) is not None:
raise RuntimeError("generated duplicated node name", node_name, self._nodes.get(node_name),
node)
self._nodes[node_name] = node
def _insert_node(self, position: Optional[Position], node: Node):
"""
Insert a node into SymbolTree.
1. Add `node` to `_nodes`.
2. Insert `node` to node list(source code order).
3. Update topological relation and update inputs of `node`.
Args:
position (Optional[Position]): Indicates node insert position. Position is None when inserting first node of
SymbolTree.
node (Node): A Node to be inserted into SymbolTree.
Raises:
RuntimeError: Position is None when _nodes of SymbolTree is not Empty. It means position can not be None
unless inserting first node.
"""
if position is None:
if self._nodes:
raise RuntimeError("self._nodes should be empty")
self._head = node
else:
if position.before_node:
position.node.insert_before(node)
else:
position.node.insert_after(node)
self._tail = node
self._add_node2nodes(node)
self._topo_mgr.on_insert_node(node)
node.set_belong_symbol_tree(self)
def _handle_custom_obj_in_normalized_args(self, node: Node):
"""
Convert CustomObjValue type argument to NamingValue type argument by storing custom object in global_vars dict.
Args:
node (Node): A Node whose arguments and keyword arguments to be handled.
"""
result: {str, ScopedValue} = {}
for arg, value in node.get_normalized_args().items():
if not isinstance(value, ScopedValue):
raise TypeError("value should be ScopedValue, got: ", type(value))
if value.type == ValueType.CustomObjValue:
field = self._node_name_namer.get_name(f"var_{type(value.value).__name__}")
self._global_vars[field] = value.value
init_targets = [ScopedValue.create_naming_value(field, "self")]
AstModifier.append_global_vars_expr_to_init(self._init_func_ast, init_targets, field)
result[arg] = init_targets[0]
else:
result[arg] = value
node.set_normalized_args(result)
def _get_cls_through_file(self):
"""
Load rewritten network class of current SymbolTree.
1. Get source code of current SymbolTree.
2. Saving source code to a tempfile.
3. Import rewritten network class using "__import__" function.
Returns:
A class handle.
"""
source = self.get_code()
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.py')
tmp_file.write(source.encode('utf8'))
tmp_file.flush()
tmp_file_name = tmp_file.name
tmp_module_path, tmp_module_file = os.path.split(tmp_file_name)
tmp_module_name = tmp_module_file[:-3]
sys.path.append(tmp_module_path)
tmp_module = __import__(tmp_module_name)
network_cls = getattr(tmp_module, self._opt_cls_name)
if network_cls is None:
raise RuntimeError("Can not find network class:", self._opt_cls_name)
return network_cls

View File

@ -0,0 +1,73 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SymbolTree builder."""
from typing import Optional
import ast
import inspect
from mindspore.nn import Cell
from .symbol_tree import SymbolTree
from .parser_register import ParserRegister
from .parser import Parser
from .ast_transformers import FlattenRecursiveStmt
class SymbolTreeBuilder:
"""SymbolTree builder."""
def __init__(self, network: Cell):
"""
Constructor of SymbolTreeBuilder.
Args:
network (Cell): An instance of Cell represents a network from which SymbolTree will be built.
"""
if not isinstance(network, Cell):
raise RuntimeError("Only support network with Cell type now")
self._origin_net = network
network_str = inspect.getsource(type(network))
self._ast_root: ast.Module = ast.parse(network_str)
self._root_tree: Optional[SymbolTree] = None
@staticmethod
def _ast_transform(ast_root: ast.AST) -> ast.AST:
"""
Optimize ast before parse.
Args:
ast_root (ast.AST): An instance of ast to be optimized.
Returns:
An instance of ast been optimized.
"""
transform_list = [FlattenRecursiveStmt()]
for transformer in transform_list:
ast_root = transformer.transform(ast_root)
return ast_root
def build(self) -> SymbolTree:
"""
Build SymbolTree.
Returns:
An instance of SymbolTree.
"""
self._ast_root = SymbolTreeBuilder._ast_transform(self._ast_root)
if not isinstance(self._ast_root, ast.Module):
raise RuntimeError("ast_root should be a ast.Module")
self._root_tree: SymbolTree = SymbolTree(self._origin_net, self._ast_root)
parser: Parser = ParserRegister.instance().get_parser(ast.Module)
parser.process(self._root_tree, self._ast_root)
return self._root_tree

View File

@ -0,0 +1,144 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SymbolTree dumper."""
import inspect
from mindspore import log as logger
from .node import Node
from .api.node_type import NodeType
from .api.scoped_value import ScopedValue, ValueType
class SymbolTreeDumper:
"""SymbolTree dumper."""
def __init__(self, symbol_tree):
"""
Constructor of SymbolTreeDumper.
Args:
symbol_tree (SymbolTree): An instance of SymbolTree to be dumped.
"""
self._symbol_tree = symbol_tree
self._dump_buffer = ""
self._dump_key2index = {}
def _reset(self):
"""Reset SymbolTreeDumper."""
self._dump_buffer = ""
self._dump_key2index = {}
def _dump_global_info(self):
"""Dump global info of SymbolTree."""
self._dump_buffer += f"#SymbolTree entry : @construct \n"
def _dump_inputs(self):
"""Dump inputs of SymbolTree."""
inputs = self._symbol_tree.get_inputs()
self._dump_buffer += f"#Inputs num : {len(inputs)}\n"
for single_input in inputs:
targets = single_input.get_targets()
if len(targets) != 1:
raise RuntimeError("Only support one output per node now")
target: ScopedValue = targets[0]
if target.type != ValueType.NamingValue:
raise RuntimeError("target.type should equal to ValueType.NamingValue")
if target.scope != "":
raise RuntimeError("target.scope should be empty")
input_arg = target.value
input_name = f"%input_{input_arg}"
if input_arg in self._dump_key2index.keys():
raise RuntimeError("input_arg duplicated: ", input_arg)
self._dump_key2index[input_arg] = input_name
self._dump_buffer += f"{input_name}\n"
self._dump_buffer += f"\n"
def _dump_nodes(self):
"""Dump nodes of SymbolTree."""
self._dump_buffer += f"Symbol Tree @construct {{ \n"
node_no = -1
node: Node = self._symbol_tree.get_head_node().get_next()
while node is not None:
if node.get_node_type() is NodeType.Output:
self._dump_buffer += f" Return(%{node_no}) \n"
self._dump_buffer += f" : (null) \n"
self._dump_buffer += f" # In file {inspect.getfile(type(self._symbol_tree.get_origin_network()))}"
node = node.get_next()
continue
node_no += 1
self._dump_key2index[node.get_name()] = f"%{node_no}"
targets = node.get_targets()
if not targets:
targets = [None]
op_type = node.get_instance_type()
if hasattr(op_type, "__name__"):
op_type_name = op_type.__name__
else:
if hasattr(type(op_type), "__name__"):
op_type_name = type(op_type).__name__
else:
raise RuntimeError("op has no attr __name__")
self._dump_buffer += f" %{node_no}({targets[0]}) = {op_type_name}"
args = node.get_normalized_args().values()
if args:
arg_str = f""
for arg in args:
if isinstance(arg, str):
arg_name = arg
elif isinstance(arg, ScopedValue):
arg_name = arg.value
else:
raise RuntimeError(f"arg type {type(arg)} of {arg} is not supported now")
if arg_name in self._dump_key2index.keys():
arg_str += f"{self._dump_key2index[arg_name]}, "
else:
logger.warning("arg not appears before")
arg_str += f"{arg_name}, "
self._dump_buffer += f"({arg_str[:-2]})"
self._dump_buffer += f"{{instance name: {node.get_name()}}}"
self._dump_buffer += f" attributes {{"
attrs = node.get_attributes()
if attrs:
attrs_str = f""
for attr in attrs:
if not isinstance(attr, str):
raise TypeError("attr should be str, got: ", type(attr))
attrs_str += f"{attr}: {attrs[attr]}, "
self._dump_buffer += attrs_str[:-2]
self._dump_buffer += f"}}\n"
self._dump_buffer += f" : (null) -> (null)\n"
cls_real_path = inspect.getfile(node.get_instance_type()) if node.get_instance() else None
self._dump_buffer += f" # In file {cls_real_path}\n"
self._dump_buffer += f" # In file {inspect.getfile(type(self._symbol_tree.get_origin_network()))}\n"
node = node.get_next()
self._dump_buffer += f"}}\n"
def dump(self):
"""Dump SymbolTree."""
self._reset()
self._dump_global_info()
self._dump_inputs()
self._dump_nodes()
print(self._dump_buffer)

View File

@ -0,0 +1,189 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SymbolTree topological-relationship manager."""
from typing import Tuple
from .api.scoped_value import ScopedValue
from .node import Node
class TopoManager:
"""SymbolTree topological-relationship manager."""
def __init__(self):
"""
Constructor of TopoManager.
Init provider and consumer.
Key of dict is an instance of ScopedValue.
Value of dict is a tuple whose first is an instance of Node, whose second is an index.
It means node's index th arg is argument
"""
self._target_provider: {ScopedValue, (Node, int)} = {}
self._target_consumer: {ScopedValue, [(Node, int)]} = {}
def get_node_users(self, node: Node) -> [Tuple[Node, int]]:
"""
Get all nodes which depend on node corresponding to node_or_name.
Args:
node (Node): An instance of node.
Returns:
A list of nodes represents node users.
"""
targets = node.get_targets()
results = []
for target in targets:
consumers = self._target_consumer.get(target)
if consumers is None:
continue
results.extend(consumers)
unique_results = []
for result in results:
if result not in unique_results:
unique_results.append(result)
return unique_results
def _add_consumer(self, product: ScopedValue, consumer: Node, index):
"""
Add a consumer to consumer dict.
Args:
product (ScopedValue): An instance of ScopedValue represents product to be consumed.
consumer (Node): An instance of Node represents consumer.
index (int): A int represents which input of consumer is the product.
"""
consumers = self._target_consumer.get(product)
if consumers is None:
self._target_consumer[product] = [(consumer, index)]
else:
self._target_consumer[product].append((consumer, index))
def _erase_provider(self, product: ScopedValue):
"""
Erase a provider from provider dict.
Args:
product (ScopedValue): An instance of ScopedValue represents product to be erased.
"""
if self._target_provider.get(product) is not None:
self._target_provider.pop(product)
def _erase_consumer(self, product: ScopedValue, consumer: Node):
"""
Erase a consumer from consumer dict.
Args:
product (ScopedValue): An instance of ScopedValue represents product whose consumer would be erased.
consumer (Node): An instance of Node which would be erased as a consumer.
"""
consumers = self._target_consumer.get(product)
if consumers is None:
return
for i in range(len(consumers) - 1, -1, -1):
exist_ele = consumers[i]
if id(exist_ele[0]) == id(consumer):
consumers.pop(i)
def _update_node_inputs(self, node: Node) -> [Node]:
"""
Update inputs of node by current provider dict and consumer dict.
Args:
node (Node): An instance of Node whose inputs will be updated.
Returns:
A list of instance of nodes represents inputs of node.
"""
if node.get_normalized_args() is None:
node.set_inputs([])
return []
inputs = []
for arg in node.get_normalized_args().values():
provider = self._target_provider.get(arg)
# some arg of some node may be self.xxx which is not an output of another node
if provider is not None:
inputs.append(provider[0])
node.set_inputs(inputs)
return inputs
def on_insert_node(self, node: Node):
"""
Update provider dict and consumer dict while inserting node into SymbolTree and update inputs of node by updated
provider dict and consumer dict.
Args:
node (Node): An instance of Node which been inserted into SymbolTree.
"""
if node.get_targets() is not None:
for i in range(0, len(node.get_targets())):
target = node.get_targets()[i]
if self._target_provider.get(target) is not None:
raise RuntimeError("target duplicated:", target)
self._target_provider[target] = (node, i)
if node.get_normalized_args() is not None:
for index, arg in enumerate(node.get_normalized_args().values()):
self._add_consumer(arg, node, index)
self._update_node_inputs(node)
def on_erase_node(self, node: Node):
"""
Update provider dict and consumer dict while erasing node from SymbolTree.
Args:
node (Node): An instance of Node which been erased from SymbolTree.
"""
if node.get_targets() is not None:
for target in node.get_targets():
consumers = self._target_consumer.get(target)
if consumers is not None and consumers:
raise RuntimeError("Only support erase isolated node: ", node.get_name(), target)
self._erase_provider(target)
if node.get_normalized_args() is not None:
for arg in node.get_normalized_args().values():
self._erase_consumer(arg, node)
# clear inputs of node rather than call _update_node_inputs because node is already erase from consumer dict
node.set_inputs([])
def on_update_arg(self, node: Node, arg_idx: int, old_arg: ScopedValue, new_arg: ScopedValue):
"""
Update provider dict and consumer dict while updating argument of node and update inputs of node by updated
provider dict and consumer dict.
Args:
node (Node): An instance of Node whose arguments being updated.
arg_idx (int): An int indicates which argument of node being updated.
old_arg (ScopedValue): An instance of ScopedValue represents original argument.
new_arg (ScopedValue): An instance of ScopedValue represents new argument.
"""
self._erase_consumer(old_arg, node)
self._add_consumer(new_arg, node, arg_idx)
self._update_node_inputs(node)
def dump(self, title=""):
"""
Dump topological relation.
Args:
title (str): A string as a title will be printed before dumping topological relation.
"""
print(f"{title}------------------------------------------------------------------------------------")
for k, v in self._target_provider.items():
print(f"{v[0].get_name()} produces {k.value}")
for k, v in self._target_consumer.items():
print(f"{k.value} is consumed by: ")
for ele in v:
print(ele[0].get_name())
print(f"-----------------------------------------------------------------------------------------")

View File

@ -12,3 +12,5 @@ pycocotools >= 2.0.2 # for st test
tables >= 3.6.1 # for st test
easydict >= 1.9 # for st test
psutil >= 5.7.0
astunparse >= 0.0
astpretty >= 0.0

View File

View File

@ -0,0 +1,64 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import ast
import inspect
from mindspore.nn import Cell, Conv2d, BatchNorm2d, ReLU
from mindspore.rewrite.ast_transformers.flatten_recursive_stmt import FlattenRecursiveStmt
class Network(Cell):
def __init__(self):
super().__init__()
self.conv = Conv2d(16, 16, 3)
self.bn = BatchNorm2d(16)
self.relu1 = ReLU()
self.relu2 = ReLU()
self.relu3 = ReLU()
def construct(self, x):
x = self.conv(x + 1)
x = x + 1 * 5 + 4 / 2 + self.bn(x)
self.relu1(x * 5)
x = self.relu2(x + 1)
x = True and x or x
x = self.relu3(x)
return x + 3
def _get_ast():
source = inspect.getsource(Network)
return ast.parse(source)
def test_flatten():
"""
Feature: Class FlattenRecursiveStmt.
Description: Apply FlattenRecursiveStmt on a simple network.
Expectation: Success.
"""
ast_node = _get_ast()
frs = FlattenRecursiveStmt()
frs.transform(ast_node)
assert len(ast_node.body) == 1
ast_class = ast_node.body[0]
assert isinstance(ast_class, ast.ClassDef)
assert len(ast_class.body) == 2
ast_init_func = ast_class.body[0]
assert isinstance(ast_init_func, ast.FunctionDef)
assert len(ast_init_func.body) == 6
ast_construct_func = ast_class.body[1]
assert isinstance(ast_construct_func, ast.FunctionDef)
assert len(ast_construct_func.body) == 17

View File

@ -0,0 +1,141 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore
from mindspore import Tensor, nn
from mindspore.rewrite import SymbolTree, ScopedValue, ValueType, Node
from mindspore.common.initializer import Normal
from mindspore.common.api import _cell_graph_executor
class SimpleNet(nn.Cell):
def __init__(self, num_class=10, num_channel=1):
super(SimpleNet, self).__init__()
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
self.var = 10
def construct(self, x):
x = self.conv1(x)
x = x
y = self.var
y = y * 5
y = y and True
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x
class MyCell(nn.Cell):
def __init__(self):
super().__init__()
self.conv = nn.Dense(5, 16)
def construct(self, x, y):
x = self.conv(x)
x = mindspore.ops.Add()(x, y)
return x
def add_conv_before_flatten(stree: SymbolTree):
new_conv_node = None
for node in stree.nodes():
if node.get_instance_type() == mindspore.nn.Flatten:
position = stree.before(node)
new_conv = nn.Conv2d(16, 16, 3)
new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv',
args=[ScopedValue.create_naming_value('self_max_po')])
stree.insert(position, new_conv_node)
break
if new_conv_node is not None:
for node in stree.nodes():
if node.get_instance_type() == mindspore.nn.Flatten:
inputs = node.get_inputs()
assert len(inputs) == 1
new_conv_node.set_arg_by_node(0, inputs[0])
def add_my_cell_after_x_12(stree: SymbolTree):
for node in stree.nodes():
targets = node.get_targets()
if targets is None:
continue
assert targets[0].type == ValueType.NamingValue
target = str(targets[0])
if target == "x_12":
position = stree.after(node)
custom_cell = MyCell()
bias = Tensor(1, mindspore.int32)
new_custom_node = Node.create_call_cell(custom_cell, targets=['nx2'],
args=[ScopedValue.create_naming_value('nx3'),
ScopedValue.create_variable_value(bias)], name='my_cell')
stree.insert(position, new_custom_node)
new_custom_node.set_arg(0, "x_12")
break
def erase_node_x_11(stree: SymbolTree):
return_node = None
for node in stree.nodes():
if node.get_targets() is None:
return_node = node
break
assert return_node is not None
for node in stree.nodes():
targets = node.get_targets()
if targets is None:
continue
assert targets[0].type == ValueType.NamingValue
target = str(targets[0])
if target == "x_11":
stree.set_output(0, "x_10")
stree.erase_node(node)
break
def transform(stree: SymbolTree):
add_conv_before_flatten(stree)
add_my_cell_after_x_12(stree)
erase_node_x_11(stree)
def test_simple_net():
"""
Feature: Module rewrite.
Description: Resolve a simple network by rewrite and do some transform on it.
Expectation: Result of rewrite can be compiled.
"""
net = SimpleNet(10)
stree = SymbolTree(net)
transform(stree)
print("------------------------------------ keys of global_vars: ",
getattr(stree.get_handler(), "_global_vars").keys())
net_opt = stree.get_network()
data_in = Tensor(np.ones([1, 1, 32, 32]), mindspore.float32)
_cell_graph_executor.compile(net_opt, data_in)

View File

@ -0,0 +1,126 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import ast
from mindspore.nn import Cell
from mindspore.rewrite import ScopedValue
from mindspore.rewrite.node import Node
class FakeCell(Cell):
def construct(self, input1, input2, cool_boy=None):
return input1 + input2 + cool_boy
class FakeCell2(Cell):
def construct(self, a, b, d, e, *args, f=6, **kwargs):
return a + b + d + e + sum(args) + f + sum(kwargs.values())
class FakeCell3(Cell):
def construct(self, a, b, *args, f=6, h=7, **kwargs):
return a + b + f + h + sum(args) + sum(kwargs.values())
def test_create_by_cell():
"""
Feature: Python api create_call_cell of Node of Rewrite.
Description: Call create_call_cell to create a node.
Expectation: Success.
"""
node = Node.create_call_cell(FakeCell(), None, ['x'], 'new_conv',
[ScopedValue.create_naming_value('x'), ScopedValue.create_variable_value(1)],
{"cool_boy": ScopedValue.create_naming_value('Naroto')}, 'new_conv')
assert node._args_num == 2
assert node._kwargs_num == 1
assert node._normalized_args_keys == ["input1", "input2", "cool_boy"]
assert node._normalized_args == {
"input1": ScopedValue.create_naming_value('x'),
"input2": ScopedValue.create_variable_value(1),
"cool_boy": ScopedValue.create_naming_value('Naroto')
}
ast_node: ast.Assign = node.get_ast()
assign_value: ast.Call = ast_node.value
args_ast = assign_value.args
keywords_ast = assign_value.keywords
assert len(args_ast) == 2
assert len(keywords_ast) == 1
assert keywords_ast[0].arg == "cool_boy"
assert isinstance(args_ast[0], ast.Name)
assert args_ast[0].id == "x"
assert isinstance(args_ast[1], ast.Constant)
assert args_ast[1].value == 1
keyword_value_3 = keywords_ast[0].value
assert isinstance(keyword_value_3, ast.Name)
assert keyword_value_3.id == "Naroto"
node.set_arg(ScopedValue.create_variable_value(2), 1)
assert isinstance(node.get_normalized_args().get("input2"), ScopedValue)
assert node.get_normalized_args().get("input2").value == 2
ast_node: ast.Assign = node.get_ast()
assign_value: ast.Call = ast_node.value
args_ast = assign_value.args
assert args_ast[1].value == 2
args = node.get_args()
assert args == [ScopedValue.create_naming_value('x'), ScopedValue.create_variable_value(2)]
kwargs = node.get_kwargs()
assert kwargs == {"cool_boy": ScopedValue.create_naming_value('Naroto')}
def test_create_by_cell2():
"""
Feature: Python api create_call_cell of Node of Rewrite.
Description: Call create_call_cell to create a node.
Expectation: Success.
"""
node = Node.create_call_cell(FakeCell2(), None, ['x'], 'new_conv',
[ScopedValue.create_naming_value('x'), ScopedValue.create_naming_value("x"),
ScopedValue.create_naming_value('x'), ScopedValue.create_naming_value("x"),
ScopedValue.create_naming_value('x'), ScopedValue.create_naming_value("x")],
{"cool_boy": ScopedValue.create_naming_value('Naroto')}, 'new_conv')
assert node.get_normalized_args() == {
"a": ScopedValue.create_naming_value('x'),
"b": ScopedValue.create_naming_value('x'),
"d": ScopedValue.create_naming_value('x'),
"e": ScopedValue.create_naming_value('x'),
"args_4": ScopedValue.create_naming_value('x'),
"args_5": ScopedValue.create_naming_value('x'),
"cool_boy": ScopedValue.create_naming_value('Naroto'),
}
def test_create_by_cell3():
"""
Feature: Python api create_call_cell of Node of Rewrite.
Description: Call create_call_cell to create a node.
Expectation: Success.
"""
node = Node.create_call_cell(FakeCell3(), None, ['x'], 'new_conv',
[ScopedValue.create_naming_value('x'), ScopedValue.create_naming_value("x"),
ScopedValue.create_naming_value('x'), ScopedValue.create_naming_value("x")],
{"h": ScopedValue.create_naming_value(1), "f": ScopedValue.create_naming_value(2),
"cool_boy": ScopedValue.create_naming_value('Naroto')}, 'new_conv')
assert node.get_normalized_args() == {
"a": ScopedValue.create_naming_value('x'),
"b": ScopedValue.create_naming_value('x'),
"args_2": ScopedValue.create_naming_value('x'),
"args_3": ScopedValue.create_naming_value('x'),
"f": ScopedValue.create_naming_value(2),
"h": ScopedValue.create_naming_value(1),
"cool_boy": ScopedValue.create_naming_value('Naroto'),
}

View File

@ -0,0 +1,665 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import ast
from collections import OrderedDict
from mindspore.nn import Cell, Conv2d, BatchNorm2d, ReLU
from mindspore.ops import Add, AddN
from mindspore.rewrite import ScopedValue, Node, SymbolTree
from mindspore.rewrite import PatternEngine, PatternNode, Replacement, VarNode
def test_tree_pattern_match():
"""
Feature: Python api PatternEngine.
Description: Construct a tree PatternEngine and apply it on a SymbolTree, check SymbolTree after PatternEngine
applied.
Expectation: Success.
"""
assert True
def test_leak_pattern_match():
"""
Feature: Python api PatternEngine.
Description: Construct a leaked tree PatternEngine and apply it on a SymbolTree, check SymbolTree after
PatternEngine applied.
Expectation: Failure.
"""
assert True
class ChainNetwork(Cell):
def __init__(self):
super().__init__()
self.conv = Conv2d(16, 16, 3)
self.bn = BatchNorm2d(16)
self.relu1 = ReLU()
self.relu2 = ReLU()
self.relu3 = ReLU()
def construct(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu1(x)
x = self.relu2(x)
x = self.relu3(x)
return x
def test_one_to_one_pattern():
"""
Feature: Python api PatternEngine.
Description: Construct a one-to-one PatternEngine and apply it on a SymbolTree, check SymbolTree after PatternEngine
applied.
Expectation: Success.
"""
class BnReplacement(Replacement):
def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
assert is_chain_pattern
assert pattern.type() == BatchNorm2d
bn_node: Node = matched.get(pattern.name())
assert bn_node is not None
conv = Conv2d(16, 16, 3)
conv_node = Node.create_call_cell(conv, ['x1'], bn_node.get_args(), bn_node.get_kwargs())
return [conv_node]
class BnReplace(PatternEngine):
def __init__(self):
super().__init__([BatchNorm2d], BnReplacement())
net = ChainNetwork()
stree = SymbolTree(net)
conv = stree.get_node("conv")
bn = stree.get_node("bn")
relu1 = stree.get_node("relu1")
construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast")
assert conv is not None
assert bn is not None
assert relu1 is not None
assert len(construct_ast.body) == 6
assert len(stree.nodes()) == 7
bn_replace = BnReplace()
bn_replace.apply(stree)
assert len(construct_ast.body) == 6
assert len(stree.nodes()) == 7
conv = stree.get_node("conv")
bn = stree.get_node("bn")
relu1 = stree.get_node("relu1")
new_conv = stree.get_node("x1")
assert conv is not None
assert bn is None
assert relu1 is not None
assert new_conv is not None
# check conv topological order
assert len(conv.get_users()) == 1
assert conv.get_users()[0] == new_conv
# check new_conv topological order
assert len(new_conv.get_inputs()) == 1
assert new_conv.get_inputs()[0] == conv
assert len(new_conv.get_users()) == 1
assert new_conv.get_users()[0] == relu1
# check source code order
assert getattr(conv.get_handler(), "_next") == new_conv.get_handler()
assert getattr(new_conv.get_handler(), "_next") == relu1.get_handler()
assert getattr(relu1.get_handler(), "_prev") == new_conv.get_handler()
assert getattr(new_conv.get_handler(), "_prev") == conv.get_handler()
# # check arg edge
assert len(conv.get_targets()) == 1
assert len(new_conv.get_args()) == 1
assert conv.get_targets()[0] == new_conv.get_args()[0]
assert len(new_conv.get_targets()) == 1
assert len(relu1.get_args()) == 1
assert new_conv.get_targets()[0] == relu1.get_args()[0]
def test_one_to_multi_chain_pattern():
"""
Feature: Python api PatternEngine.
Description: Construct a one-to-multi PatternEngine and apply it on a SymbolTree, check SymbolTree after
PatternEngine applied.
Expectation: Success.
"""
class BnReplacement(Replacement):
def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
assert is_chain_pattern
assert pattern.type() == BatchNorm2d
bn_node: Node = matched.get(pattern.name())
assert bn_node is not None
# Replacement should ensure target is unique in result
# Replacement should ensure args and kwargs are well set by topological relation
conv1 = Conv2d(16, 16, 3)
conv_node1 = Node.create_call_cell(conv1, ['x1'], bn_node.get_args(), bn_node.get_kwargs())
conv2 = Conv2d(16, 16, 5)
conv_node2 = Node.create_call_cell(conv2, ['x2'], [ScopedValue.create_naming_value('x1')])
return [conv_node1, conv_node2]
class BnReplace(PatternEngine):
def __init__(self):
super().__init__([BatchNorm2d], BnReplacement())
net = ChainNetwork()
stree = SymbolTree(net)
conv = stree.get_node("conv")
bn = stree.get_node("bn")
relu1 = stree.get_node("relu1")
construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast")
assert conv is not None
assert bn is not None
assert relu1 is not None
assert len(construct_ast.body) == 6
assert len(stree.nodes()) == 7
bn_replace = BnReplace()
bn_replace.apply(stree)
assert len(construct_ast.body) == 7
assert len(stree.nodes()) == 8
conv = stree.get_node("conv")
bn = stree.get_node("bn")
relu1 = stree.get_node("relu1")
new_conv1 = stree.get_node("x1")
new_conv2 = stree.get_node("x2")
assert conv is not None
assert bn is None
assert relu1 is not None
assert new_conv1 is not None
assert new_conv2 is not None
# check conv topological order
assert len(conv.get_users()) == 1
assert conv.get_users()[0] == new_conv1
# check new_conv1 topological order
assert len(new_conv1.get_inputs()) == 1
assert new_conv1.get_inputs()[0] == conv
assert len(new_conv1.get_users()) == 1
assert new_conv1.get_users()[0] == new_conv2
# check new_conv2 topological order
assert len(new_conv2.get_inputs()) == 1
assert new_conv2.get_inputs()[0] == new_conv1
assert len(new_conv2.get_users()) == 1
assert new_conv2.get_users()[0] == relu1
# check source code order
assert getattr(conv.get_handler(), "_next") == new_conv1.get_handler()
assert getattr(new_conv1.get_handler(), "_next") == new_conv2.get_handler()
assert getattr(new_conv2.get_handler(), "_next") == relu1.get_handler()
assert getattr(relu1.get_handler(), "_prev") == new_conv2.get_handler()
assert getattr(new_conv2.get_handler(), "_prev") == new_conv1.get_handler()
assert getattr(new_conv1.get_handler(), "_prev") == conv.get_handler()
# check arg edge
assert len(conv.get_targets()) == 1
assert len(new_conv1.get_args()) == 1
assert conv.get_targets()[0] == new_conv1.get_args()[0]
assert len(new_conv1.get_targets()) == 1
assert len(new_conv2.get_args()) == 1
assert new_conv1.get_targets()[0] == new_conv2.get_args()[0]
assert len(new_conv2.get_targets()) == 1
assert len(relu1.get_args()) == 1
assert new_conv2.get_targets()[0] == relu1.get_args()[0]
class TreeNetwork(Cell):
def __init__(self):
super().__init__()
self.conv1 = Conv2d(16, 16, 3)
self.conv2 = Conv2d(16, 16, 5)
self.add = Add()
self.relu = ReLU()
self.relu1 = ReLU()
self.relu2 = ReLU()
def construct(self, x):
x1 = self.conv1(x)
x2 = self.conv2(x)
x = self.add(x1, x2)
x = self.relu(x)
x1 = self.relu1(x)
x2 = self.relu2(x)
x = self.add(x1, x2)
return x
def test_tree_pattern():
"""
Feature: Python api PatternEngine.
Description: Construct a multi-to-multi PatternEngine and apply it on a SymbolTree, check SymbolTree after
PatternEngine applied.
Expectation: Success.
"""
class AddReluReplacement(Replacement):
def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
assert is_chain_pattern
assert pattern.type() == ReLU
relu_node: Node = matched.get(pattern.name())
assert relu_node is not None
assert len(pattern.get_inputs()) == 1
add_pattern = pattern.get_inputs()[0]
assert add_pattern.type() == Add
add_node: Node = matched.get(add_pattern.name())
assert add_node is not None
assert not add_pattern.get_inputs()
# can not use add_node here
new_add1 = Add()
new_add1_node = Node.create_call_cell(new_add1, ['new_add_1'], add_node.get_args(), add_node.get_kwargs())
new_relu1 = ReLU()
new_relu1_node = Node.create_call_cell(new_relu1, ['new_relu_1'],
[ScopedValue.create_naming_value('new_add_1')])
new_relu2 = ReLU()
new_relu2_node = Node.create_call_cell(new_relu2, ['new_relu_2'],
[ScopedValue.create_naming_value('new_add_1')])
new_add2 = Add()
new_add2_node = Node.create_call_cell(new_add2, ['new_add_2'],
[ScopedValue.create_naming_value('new_relu_1'),
ScopedValue.create_naming_value('new_relu_2')])
return [new_add1_node, new_relu1_node, new_relu2_node, new_add2_node]
class AddReluPattern(PatternEngine):
def __init__(self):
super().__init__([Add, ReLU], AddReluReplacement())
net = TreeNetwork()
stree = SymbolTree(net)
conv1 = stree.get_node("conv1")
conv2 = stree.get_node("conv2")
add = stree.get_node("add")
relu = stree.get_node("relu")
relu1 = stree.get_node("relu1")
relu2 = stree.get_node("relu2")
assert conv1 is not None
assert conv2 is not None
assert add is not None
assert relu is not None
assert relu1 is not None
assert relu2 is not None
construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast")
assert len(construct_ast.body) == 8
assert len(stree.nodes()) == 9
add_relu_pattern = AddReluPattern()
add_relu_pattern.apply(stree)
assert len(construct_ast.body) == 10
assert len(stree.nodes()) == 11
conv1 = stree.get_node("conv1")
conv2 = stree.get_node("conv2")
add = stree.get_node("add")
relu = stree.get_node("relu")
relu1 = stree.get_node("relu1")
relu2 = stree.get_node("relu2")
new_add = stree.get_node("new_add")
new_relu = stree.get_node("new_relu")
new_relu_1 = stree.get_node("new_relu_1")
new_add_1 = stree.get_node("new_add_1")
assert conv1 is not None
assert conv2 is not None
assert add is None
assert relu is None
assert relu1 is not None
assert relu2 is not None
assert new_add is not None
assert new_relu is not None
assert new_relu_1 is not None
assert new_add_1 is not None
# check conv1 topological order
assert len(conv1.get_users()) == 1
assert conv1.get_users()[0] == new_add
# check conv2 topological order
assert len(conv2.get_users()) == 1
assert conv2.get_users()[0] == new_add
# check new_add topological order
assert len(new_add.get_inputs()) == 2
assert new_add.get_inputs()[0] == conv1
assert new_add.get_inputs()[1] == conv2
assert len(new_add.get_users()) == 2
assert new_add.get_users()[0] == new_relu
assert new_add.get_users()[1] == new_relu_1
# check new_relu topological order
assert len(new_relu.get_inputs()) == 1
assert new_relu.get_inputs()[0] == new_add
assert len(new_relu.get_users()) == 1
assert new_relu.get_users()[0] == new_add_1
# check new_relu_1 topological order
assert len(new_relu_1.get_inputs()) == 1
assert new_relu_1.get_inputs()[0] == new_add
assert len(new_relu_1.get_users()) == 1
assert new_relu_1.get_users()[0] == new_add_1
# check new_add_1 topological order
assert len(new_add_1.get_inputs()) == 2
assert new_add_1.get_inputs()[0] == new_relu_1
assert new_add_1.get_inputs()[1] == new_relu
assert len(new_add_1.get_users()) == 2
assert new_add_1.get_users()[0] == relu1
assert new_add_1.get_users()[1] == relu2
# check source code order
assert getattr(conv1.get_handler(), "_next") == conv2.get_handler()
assert getattr(conv2.get_handler(), "_next") == new_add.get_handler()
assert getattr(new_add.get_handler(), "_next") == new_relu.get_handler()
assert getattr(new_relu.get_handler(), "_next") == new_relu_1.get_handler()
assert getattr(new_relu_1.get_handler(), "_next") == new_add_1.get_handler()
assert getattr(new_add_1.get_handler(), "_next") == relu1.get_handler()
assert getattr(relu1.get_handler(), "_prev") == new_add_1.get_handler()
assert getattr(new_add_1.get_handler(), "_prev") == new_relu_1.get_handler()
assert getattr(new_relu_1.get_handler(), "_prev") == new_relu.get_handler()
assert getattr(new_relu.get_handler(), "_prev") == new_add.get_handler()
assert getattr(new_add.get_handler(), "_prev") == conv2.get_handler()
assert getattr(conv2.get_handler(), "_prev") == conv1.get_handler()
# check arg edge
assert len(conv1.get_targets()) == 1
assert len(conv2.get_targets()) == 1
assert len(new_add.get_args()) == 2
assert conv1.get_targets()[0] == new_add.get_args()[0]
assert conv2.get_targets()[0] == new_add.get_args()[1]
assert len(new_add.get_targets()) == 1
assert len(new_relu.get_args()) == 1
assert len(new_relu_1.get_args()) == 1
assert new_add.get_targets()[0] == new_relu.get_args()[0]
assert new_add.get_targets()[0] == new_relu_1.get_args()[0]
assert len(new_relu.get_targets()) == 1
assert len(new_relu_1.get_targets()) == 1
assert len(new_add_1.get_args()) == 2
assert new_relu.get_targets()[0] == new_add_1.get_args()[1]
assert new_relu_1.get_targets()[0] == new_add_1.get_args()[0]
assert len(new_add_1.get_targets()) == 1
assert len(relu1.get_args()) == 1
assert len(relu2.get_args()) == 1
assert new_add_1.get_targets()[0] == relu1.get_args()[0]
assert new_add_1.get_targets()[0] == relu2.get_args()[0]
class TreeNetwork2(Cell):
def __init__(self):
super().__init__()
self.conv1 = Conv2d(16, 16, 1)
self.conv2 = Conv2d(16, 16, 3)
self.add1 = AddN()
self.add2 = AddN()
self.relu = ReLU()
def construct(self, x, y, z):
x = self.conv1(x)
y = self.conv2(y)
z = self.add1(x, y, z)
z = self.add2(x, y, z)
z = self.relu(z)
return z
class MultiInputPattern(PatternEngine):
class MultiInputReplacement(Replacement):
def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
assert not is_chain_pattern
assert pattern.type() == AddN
addn2_node: Node = matched.get(pattern.name())
assert addn2_node is not None
assert len(pattern.get_inputs()) == 3
conv1_pn = pattern.get_inputs()[0]
conv2_pn = pattern.get_inputs()[1]
addn1_pn = pattern.get_inputs()[2]
assert conv1_pn.type() == Conv2d
assert conv2_pn.type() == Conv2d
assert addn1_pn.type() == AddN
conv1_node: Node = matched.get(conv1_pn.name())
conv2_node: Node = matched.get(conv2_pn.name())
addn1_node: Node = matched.get(addn1_pn.name())
assert conv1_node is not None
assert conv2_node is not None
assert addn1_node is not None
assert len(conv1_node.get_inputs()) == 1
assert len(conv2_node.get_inputs()) == 1
assert len(addn1_node.get_inputs()) == 3
arg1 = conv1_node.get_args()[0]
arg2 = conv2_node.get_args()[0]
arg3 = addn1_node.get_args()[2]
# can not use add_node here
new_add1 = Add()
new_add1_node = Node.create_call_cell(new_add1, ['new_add1'], [arg1, arg2])
new_add2 = Add()
new_add2_node = Node.create_call_cell(new_add2, ['new_add2'], [ScopedValue.create_naming_value('new_add1'),
arg3])
return [new_add1_node, new_add2_node]
def __init__(self):
conv1_pn = PatternNode("conv1", Conv2d)
conv2_pn = PatternNode("conv2", Conv2d)
addn1_pn = PatternNode("addn1", AddN)
addn2_pn = PatternNode("addn2", AddN)
conv1_pn.set_inputs([VarNode()])
conv2_pn.set_inputs([VarNode()])
addn1_pn.set_inputs([conv1_pn, conv2_pn, VarNode()])
addn2_pn.set_inputs([conv1_pn, conv2_pn, addn1_pn])
super().__init__(addn2_pn, MultiInputPattern.MultiInputReplacement())
def test_multi_input_to_multi_pattern_tree_pattern():
"""
Feature: Python api PatternEngine.
Description: Construct a multi-to-multi PatternEngine and apply it on a SymbolTree, check SymbolTree after
PatternEngine applied.
Expectation: Success.
"""
net = TreeNetwork2()
stree = SymbolTree(net)
conv1 = stree.get_node("conv1")
conv2 = stree.get_node("conv2")
add1 = stree.get_node("add1")
add2 = stree.get_node("add2")
relu = stree.get_node("relu")
assert conv1 is not None
assert conv2 is not None
assert add1 is not None
assert add2 is not None
assert relu is not None
construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast")
assert len(construct_ast.body) == 6
assert len(stree.nodes()) == 9
multi_input_pattern = MultiInputPattern()
multi_input_pattern.apply(stree)
assert len(construct_ast.body) == 4
assert len(stree.nodes()) == 7
conv1 = stree.get_node("conv1")
conv2 = stree.get_node("conv2")
add1 = stree.get_node("add1")
add2 = stree.get_node("add2")
relu = stree.get_node("relu")
new_add1 = stree.get_node("new_add1")
new_add2 = stree.get_node("new_add2")
inputx = stree.get_node("input_x")
inputy = stree.get_node("input_y")
inputz = stree.get_node("input_z")
assert conv1 is None
assert conv2 is None
assert add1 is None
assert add2 is None
assert relu is not None
assert new_add1 is not None
assert new_add2 is not None
assert inputx is not None
assert inputy is not None
assert inputz is not None
# check inputx topological order
assert len(inputx.get_users()) == 1
assert inputx.get_users()[0] == new_add1
# check inputy topological order
assert len(inputy.get_users()) == 1
assert inputy.get_users()[0] == new_add1
# check inputz topological order
assert len(inputz.get_users()) == 1
assert inputz.get_users()[0] == new_add2
# check new_add1 topological order
assert len(new_add1.get_inputs()) == 2
assert new_add1.get_inputs()[0] == inputx
assert new_add1.get_inputs()[1] == inputy
assert len(new_add1.get_users()) == 1
assert new_add1.get_users()[0] == new_add2
# check new_add2 topological order
assert len(new_add2.get_inputs()) == 2
assert new_add2.get_inputs()[0] == new_add1
assert new_add2.get_inputs()[1] == inputz
assert len(new_add2.get_users()) == 1
assert new_add2.get_users()[0] == relu
# check relu topological order
assert len(relu.get_inputs()) == 1
assert relu.get_inputs()[0] == new_add2
# check source code order
assert getattr(inputz.get_handler(), "_next") == new_add1.get_handler()
assert getattr(new_add1.get_handler(), "_next") == new_add2.get_handler()
assert getattr(new_add2.get_handler(), "_next") == relu.get_handler()
assert getattr(relu.get_handler(), "_prev") == new_add2.get_handler()
assert getattr(new_add2.get_handler(), "_prev") == new_add1.get_handler()
assert getattr(new_add1.get_handler(), "_prev") == inputz.get_handler()
# check arg edge
assert len(inputx.get_targets()) == 1
assert len(inputy.get_targets()) == 1
assert len(new_add1.get_args()) == 2
assert inputx.get_targets()[0] == new_add1.get_args()[0]
assert inputy.get_targets()[0] == new_add1.get_args()[1]
assert len(inputz.get_targets()) == 1
assert len(new_add1.get_targets()) == 1
assert len(new_add2.get_args()) == 2
assert new_add1.get_targets()[0] == new_add2.get_args()[0]
assert inputz.get_targets()[0] == new_add2.get_args()[1]
assert len(new_add2.get_targets()) == 1
assert len(relu.get_args()) == 1
assert new_add2.get_targets()[0] == relu.get_args()[0]
class TreeNetwork3(Cell):
def __init__(self):
super().__init__()
self.conv1 = Conv2d(16, 16, 1)
self.conv2 = Conv2d(16, 16, 3)
self.add1 = AddN()
self.add2 = AddN()
self.relu = ReLU()
def construct(self, x):
y = self.conv1(x)
z = self.conv2(x)
x = self.add1(y, z, x)
x = self.add2(y, z, x)
x = self.relu(x)
return x
def test_one_input_to_multi_pattern_tree_pattern():
"""
Feature: Python api PatternEngine.
Description: Construct a multi-to-multi PatternEngine and apply it on a SymbolTree, check SymbolTree after
PatternEngine applied.
Expectation: Success.
"""
net = TreeNetwork3()
stree = SymbolTree(net)
conv1 = stree.get_node("conv1")
conv2 = stree.get_node("conv2")
add1 = stree.get_node("add1")
add2 = stree.get_node("add2")
relu = stree.get_node("relu")
assert conv1 is not None
assert conv2 is not None
assert add1 is not None
assert add2 is not None
assert relu is not None
construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast")
assert len(construct_ast.body) == 6
assert len(stree.nodes()) == 7
multi_input_pattern = MultiInputPattern()
multi_input_pattern.apply(stree)
assert len(construct_ast.body) == 4
assert len(stree.nodes()) == 5
conv1 = stree.get_node("conv1")
conv2 = stree.get_node("conv2")
add1 = stree.get_node("add1")
add2 = stree.get_node("add2")
relu = stree.get_node("relu")
new_add1 = stree.get_node("new_add1")
new_add2 = stree.get_node("new_add2")
inputx = stree.get_node("input_x")
assert conv1 is None
assert conv2 is None
assert add1 is None
assert add2 is None
assert relu is not None
assert new_add1 is not None
assert new_add2 is not None
assert inputx is not None
# check inputx topological order
assert len(inputx.get_users()) == 2
assert inputx.get_users()[0] == new_add1
assert inputx.get_users()[1] == new_add2
# check new_add1 topological order
assert len(new_add1.get_inputs()) == 2
assert new_add1.get_inputs()[0] == inputx
assert new_add1.get_inputs()[1] == inputx
assert len(new_add1.get_users()) == 1
assert new_add1.get_users()[0] == new_add2
# check new_add2 topological order
assert len(new_add2.get_inputs()) == 2
assert new_add2.get_inputs()[0] == new_add1
assert new_add2.get_inputs()[1] == inputx
assert len(new_add2.get_users()) == 1
assert new_add2.get_users()[0] == relu
# check relu topological order
assert len(relu.get_inputs()) == 1
assert relu.get_inputs()[0] == new_add2
# check source code order
assert getattr(inputx.get_handler(), "_next") == new_add1.get_handler()
assert getattr(new_add1.get_handler(), "_next") == new_add2.get_handler()
assert getattr(new_add2.get_handler(), "_next") == relu.get_handler()
assert getattr(relu.get_handler(), "_prev") == new_add2.get_handler()
assert getattr(new_add2.get_handler(), "_prev") == new_add1.get_handler()
assert getattr(new_add1.get_handler(), "_prev") == inputx.get_handler()
# check arg edge
assert len(inputx.get_targets()) == 1
assert len(new_add1.get_args()) == 2
assert inputx.get_targets()[0] == new_add1.get_args()[0]
assert inputx.get_targets()[0] == new_add1.get_args()[1]
assert len(inputx.get_targets()) == 1
assert len(new_add1.get_targets()) == 1
assert len(new_add2.get_args()) == 2
assert new_add1.get_targets()[0] == new_add2.get_args()[0]
assert inputx.get_targets()[0] == new_add2.get_args()[1]
assert len(new_add2.get_targets()) == 1
assert len(relu.get_args()) == 1
assert new_add2.get_targets()[0] == relu.get_args()[0]

View File

@ -0,0 +1,388 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import ast
import inspect
from mindspore.nn import Cell, Conv2d, BatchNorm2d, ReLU
from mindspore.ops import Add
from mindspore.rewrite import ScopedValue, ValueType, NodeType
from mindspore.rewrite import Node as NodeApi
from mindspore.rewrite.symbol_tree import SymbolTree
from mindspore.rewrite.node import Node
class Network(Cell):
def __init__(self):
super().__init__()
self.conv = Conv2d(16, 16, 3)
self.bn = BatchNorm2d(16)
self.relu1 = ReLU()
self.relu2 = ReLU()
self.relu3 = ReLU()
def construct(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu1(x)
x = self.relu2(x)
x = self.relu3(x)
return x
def create_symbol_tree():
net = Network()
source = inspect.getsource(type(net))
ast_root = ast.parse(source)
ast_module = ast_root
assert isinstance(ast_root, ast.Module)
ast_class = ast_module.body[0]
assert isinstance(ast_class, ast.ClassDef)
ast_init_func = ast_class.body[0]
assert isinstance(ast_init_func, ast.FunctionDef)
ast_construct_func = ast_class.body[1]
assert isinstance(ast_construct_func, ast.FunctionDef)
ast_conv = ast_construct_func.body[0]
ast_bn = ast_construct_func.body[1]
ast_relu1 = ast_construct_func.body[2]
ast_relu2 = ast_construct_func.body[3]
ast_relu3 = ast_construct_func.body[4]
ast_return = ast_construct_func.body[5]
stree = SymbolTree(net, ast_module)
stree.set_class_ast(ast_class)
stree.set_init_func_ast(ast_init_func)
stree.set_ast_root(ast_construct_func)
stree.append_input_node("x")
conv_node = Node.create_call_cell(net.conv, ast_conv, [ScopedValue.create_naming_value("x")],
ScopedValue.create_naming_value("conv", "self"),
[ScopedValue.create_naming_value("x")],
{},
"conv")
stree.append_origin_field(conv_node)
bn_node = Node.create_call_cell(net.bn, ast_bn, [ScopedValue.create_naming_value("x")],
ScopedValue.create_naming_value("bn", "self"),
[ScopedValue.create_naming_value("x")], {},
"bn")
bn_node = stree.append_origin_field(bn_node)
relu1_node = Node.create_call_cell(net.relu1, ast_relu1, [ScopedValue.create_naming_value("x")],
ScopedValue.create_naming_value("relu1", "self"),
[ScopedValue.create_naming_value("x")],
{}, "relu1")
relu1_node = stree.append_origin_field(relu1_node)
relu2_node = Node.create_call_cell(net.relu2, ast_relu2, [ScopedValue.create_naming_value("x")],
ScopedValue.create_naming_value("relu2", "self"),
[ScopedValue.create_naming_value("x")],
{}, "relu2")
relu2_node = stree.append_origin_field(relu2_node)
relu3_node = Node.create_call_cell(net.relu3, ast_relu3, [ScopedValue.create_naming_value("x")],
ScopedValue.create_naming_value("relu3", "self"),
[ScopedValue.create_naming_value("x")],
{}, "relu3")
stree.append_origin_field(relu3_node)
node_return = Node.create_output_node(ast_return, ["x"])
stree.append_origin_field(node_return)
return stree, bn_node, relu1_node, relu2_node
def test_insert_node():
"""
Feature: Python api insert_node of SymbolTree of Rewrite.
Description: Call insert_node to insert a node into SymbolTree.
Expectation: Success.
"""
stree, _, relu1, relu2 = create_symbol_tree()
construct_ast: ast.FunctionDef = getattr(stree, "_root_ast")
providers = getattr(getattr(stree, "_topo_mgr"), "_target_provider")
consumers = getattr(getattr(stree, "_topo_mgr"), "_target_consumer")
providers_len = len(providers)
consumers_len = len(consumers)
assert len(stree.nodes()) == 7
assert len(construct_ast.body) == 6
assert len(relu1.get_targets()) == 1
assert len(relu2.get_normalized_args().values()) == 1
assert relu1.get_targets()[0] == list(relu2.get_normalized_args().values())[0]
input1 = 1
node = Node.create_call_cell(Add(), None, ['x'], 'new_conv',
[ScopedValue.create_naming_value('x'), ScopedValue.create_variable_value(input1)], {},
'new_conv')
position = stree.before(relu2)
node = stree.insert_node(position, node)
# check nodes size
assert len(stree.nodes()) == 8
# check args
assert len(relu2.get_normalized_args().values()) == 1
assert relu1.get_targets()[0] == list(relu2.get_normalized_args().values())[0]
assert len(node.get_normalized_args().values()) == 2
assert list(node.get_normalized_args().values())[0] == ScopedValue.create_naming_value('x')
assert list(node.get_normalized_args().values())[1].type == ValueType.IntValue
# check provider
assert len(providers) == providers_len + 1
assert len(node.get_targets()) == 1
assert providers.get(node.get_targets()[0])[0] == node
assert providers.get(node.get_targets()[0])[1] == 0
# check consumer
assert len(consumers) == consumers_len + 1
assert consumers.get(list(node.get_normalized_args().values())[1]) is not None
# check inputs
assert len(relu2.get_inputs()) == 1
assert relu2.get_inputs()[0] == relu1
assert len(node.get_inputs()) == 1
assert node.get_inputs()[0].get_node_type() == NodeType.Input
# check ast
node_ast = node.get_ast()
assert isinstance(node_ast, ast.Assign)
args = node_ast.value.args
assert isinstance(args, list)
assert len(args) == 2
assert isinstance(args[0], ast.Name)
assert isinstance(args[1], ast.Constant)
assert len(construct_ast.body) == 7
def test_set_node_arg():
"""
Feature: Python api set_node_arg of SymbolTree of Rewrite.
Description: Call set_node_arg to change topological-order of a node.
Expectation: Success.
"""
stree, bn, relu1, relu2 = create_symbol_tree()
assert len(stree.nodes()) == 7
assert len(bn.get_targets()) == 1
bn_output = bn.get_targets()[0]
# check bn topological order
assert len(stree.get_node_users(bn)) == 1
assert stree.get_node_users(bn)[0][0] == relu1
# check relu1 topological order
assert len(stree.get_node_inputs(relu1)) == 1
assert stree.get_node_inputs(relu1)[0] == bn
assert len(stree.get_node_users(relu1)) == 1
assert stree.get_node_users(relu1)[0][0] == relu2
# check relu2 topological order
assert len(stree.get_node_inputs(relu2)) == 1
assert stree.get_node_inputs(relu2)[0] == relu1
# check relu1 and relu2 edge
assert len(relu1.get_targets()) == 1
assert len(relu2.get_normalized_args().values()) == 1
assert relu1.get_targets()[0] == list(relu2.get_normalized_args().values())[0]
stree.set_node_arg(relu2, 0, bn_output)
# check bn topological order
assert len(stree.get_node_users(bn)) == 2
assert stree.get_node_users(bn)[0][0] == relu1
assert stree.get_node_users(bn)[1][0] == relu2
# check relu1 topological order
assert len(stree.get_node_inputs(relu1)) == 1
assert stree.get_node_inputs(relu1)[0] == bn
assert len(stree.get_node_users(relu1)) == 0
# check relu2 topological order
assert len(stree.get_node_inputs(relu2)) == 1
assert stree.get_node_inputs(relu2)[0] == bn
# check bn and relu2 edge
assert len(relu1.get_targets()) == 1
assert len(relu2.get_normalized_args().values()) == 1
assert bn_output == list(relu2.get_normalized_args().values())[0]
# check ast
node_ast = relu2.get_ast()
assert isinstance(node_ast, ast.Assign)
args = node_ast.value.args
assert isinstance(args, list)
assert len(args) == 1
assert isinstance(args[0], ast.Name)
assert args[0].id == bn_output.value
def test_set_node_arg_by_node():
"""
Feature: Python api set_node_arg_by_node of SymbolTree of Rewrite.
Description: Call set_node_arg_by_node to change topological-order of a node.
Expectation: Success.
"""
stree, bn, relu1, relu2 = create_symbol_tree()
assert len(stree.nodes()) == 7
assert len(bn.get_targets()) == 1
bn_output = bn.get_targets()[0]
# check bn topological order
assert len(stree.get_node_users(bn)) == 1
assert stree.get_node_users(bn)[0][0] == relu1
# check relu1 topological order
assert len(stree.get_node_inputs(relu1)) == 1
assert stree.get_node_inputs(relu1)[0] == bn
assert len(stree.get_node_users(relu1)) == 1
assert stree.get_node_users(relu1)[0][0] == relu2
# check relu2 topological order
assert len(stree.get_node_inputs(relu2)) == 1
assert stree.get_node_inputs(relu2)[0] == relu1
# check relu1 and relu2 edge
assert len(relu1.get_targets()) == 1
assert len(relu2.get_normalized_args().values()) == 1
assert relu1.get_targets()[0] == list(relu2.get_normalized_args().values())[0]
stree.set_node_arg_by_node(relu2, 0, bn)
# check bn topological order
assert len(stree.get_node_users(bn)) == 2
assert stree.get_node_users(bn)[0][0] == relu1
assert stree.get_node_users(bn)[1][0] == relu2
# check relu1 topological order
assert len(stree.get_node_inputs(relu1)) == 1
assert stree.get_node_inputs(relu1)[0] == bn
assert len(stree.get_node_users(relu1)) == 0
# check relu2 topological order
assert len(stree.get_node_inputs(relu2)) == 1
assert stree.get_node_inputs(relu2)[0] == bn
# check bn and relu2 edge
assert len(relu1.get_targets()) == 1
assert len(relu2.get_normalized_args().values()) == 1
assert bn_output == list(relu2.get_normalized_args().values())[0]
# check ast
node_ast = relu2.get_ast()
assert isinstance(node_ast, ast.Assign)
args = node_ast.value.args
assert isinstance(args, list)
assert len(args) == 1
assert isinstance(args[0], ast.Name)
assert args[0].id == bn_output.value
def test_erase_succeed():
"""
Feature: Python api erase_node of SymbolTree of Rewrite.
Description: Call erase_node to erase a node from SymbolTree.
Expectation: Success.
"""
stree, bn, relu1, relu2 = create_symbol_tree()
construct_ast: ast.FunctionDef = getattr(stree, "_root_ast")
providers = getattr(getattr(stree, "_topo_mgr"), "_target_provider")
providers_len = len(providers)
assert len(stree.nodes()) == 7
assert len(construct_ast.body) == 6
stree.set_node_arg_by_node(relu2, 0, bn)
stree.erase_node(relu1)
assert len(stree.nodes()) == 6
assert len(providers) == providers_len - 1
assert len(construct_ast.body) == 5
def test_erase_failed():
"""
Feature: Python api erase_node of SymbolTree of Rewrite.
Description: Call erase_node to erase a node from SymbolTree which is not isolated.
Expectation: Failure.
"""
stree, _, relu1, _ = create_symbol_tree()
catched_error = False
try:
stree.erase_node(relu1)
except RuntimeError:
catched_error = True
assert catched_error
def test_replace_one_to_one():
"""
Feature: Python api replace of SymbolTree of Rewrite.
Description: Call replace to replace an origin node to a new node.
Expectation: Success.
"""
stree, bn, relu1, relu2 = create_symbol_tree()
construct_ast: ast.FunctionDef = getattr(stree, "_root_ast")
assert len(construct_ast.body) == 6
assert len(stree.nodes()) == 7
new_conv = Conv2d(16, 16, 5)
new_conv_node = NodeApi.create_call_cell(new_conv, [ScopedValue.create_naming_value("new_conv")],
bn.get_targets()).get_handler()
new_conv_node = stree.replace(relu1, [new_conv_node])
assert len(stree.nodes()) == 7
# check ast
assert len(construct_ast.body) == 6
node_ast: ast.Assign = construct_ast.body[2]
func_ast: ast.Attribute = node_ast.value.func
assert func_ast.attr == new_conv_node.get_name()
# check bn topological order
assert len(stree.get_node_users(bn)) == 1
assert stree.get_node_users(bn)[0][0] == new_conv_node
# check new_conv_node topological order
assert len(stree.get_node_inputs(new_conv_node)) == 1
assert stree.get_node_inputs(new_conv_node)[0] == bn
assert len(stree.get_node_users(new_conv_node)) == 1
assert stree.get_node_users(new_conv_node)[0][0] == relu2
# check relu2 topological order
assert len(stree.get_node_inputs(relu2)) == 1
assert stree.get_node_inputs(relu2)[0] == new_conv_node
# check arg edge
assert len(bn.get_targets()) == 1
assert len(new_conv_node.get_normalized_args().values()) == 1
assert bn.get_targets()[0] == list(new_conv_node.get_normalized_args().values())[0]
assert len(new_conv_node.get_targets()) == 1
assert len(relu2.get_normalized_args().values()) == 1
assert new_conv_node.get_targets()[0] == list(relu2.get_normalized_args().values())[0]
def test_replace_one_to_multi():
"""
Feature: Python api replace of SymbolTree of Rewrite.
Description: Call replace to replace an origin node to a new node-tree.
Expectation: Success.
"""
stree, bn, relu1, relu2 = create_symbol_tree()
construct_ast: ast.FunctionDef = getattr(stree, "_root_ast")
assert len(construct_ast.body) == 6
assert len(stree.nodes()) == 7
new_conv_node = NodeApi.create_call_cell(Conv2d(16, 16, 5), [ScopedValue.create_naming_value("new_conv")],
bn.get_targets()).get_handler()
new_relu_node = NodeApi.create_call_cell(ReLU(), [ScopedValue.create_naming_value("new_relu")],
new_conv_node.get_targets()).get_handler()
new_relu_node = stree.replace(relu1, [new_relu_node, new_conv_node])
new_conv_node = new_relu_node.get_inputs()[0]
assert len(stree.nodes()) == 8
# check ast
assert len(construct_ast.body) == 7
new_conv_ast: ast.Assign = construct_ast.body[2]
new_conv_func_ast: ast.Attribute = new_conv_ast.value.func
assert new_conv_func_ast.attr == new_conv_node.get_name()
new_relu_ast: ast.Assign = construct_ast.body[3]
new_relu_func_ast: ast.Attribute = new_relu_ast.value.func
assert new_relu_func_ast.attr == new_relu_node.get_name()
# check bn topological order
assert len(stree.get_node_users(bn)) == 1
assert stree.get_node_users(bn)[0][0] == new_conv_node
# check new_conv_node topological order
assert len(stree.get_node_inputs(new_conv_node)) == 1
assert stree.get_node_inputs(new_conv_node)[0] == bn
assert len(stree.get_node_users(new_conv_node)) == 1
assert stree.get_node_users(new_conv_node)[0][0] == new_relu_node
# check new_relu_node topological order
assert len(stree.get_node_inputs(new_relu_node)) == 1
assert stree.get_node_inputs(new_relu_node)[0] == new_conv_node
assert len(stree.get_node_users(new_relu_node)) == 1
assert stree.get_node_users(new_relu_node)[0][0] == relu2
# check relu2 topological order
assert len(stree.get_node_inputs(relu2)) == 1
assert stree.get_node_inputs(relu2)[0] == new_relu_node
# check arg edge
assert len(bn.get_targets()) == 1
assert len(new_conv_node.get_normalized_args().values()) == 1
assert bn.get_targets()[0] == list(new_conv_node.get_normalized_args().values())[0]
assert len(new_conv_node.get_targets()) == 1
assert len(new_relu_node.get_normalized_args().values()) == 1
assert new_conv_node.get_targets()[0] == list(new_relu_node.get_normalized_args().values())[0]
assert len(new_relu_node.get_targets()) == 1
assert len(relu2.get_normalized_args().values()) == 1
assert new_relu_node.get_targets()[0] == list(relu2.get_normalized_args().values())[0]

View File

@ -119,6 +119,12 @@ else
if [ ${RET} -ne 0 ]; then
exit ${RET}
fi
pytest -s $CURRPATH/rewrite/*.py
RET=$?
if [ ${RET} -ne 0 ]; then
exit ${RET}
fi
fi
RET=$?