forked from mindspore-Ecosystem/mindspore
add mindpore rewrite
This commit is contained in:
parent
65c881a246
commit
de196f3a25
|
@ -38,6 +38,11 @@
|
|||
"mindspore/mindspore/python/mindspore/train/serialization.py" "protected-access"
|
||||
"mindspore/mindspore/python/mindspore/train/model.py" "protected-access"
|
||||
"mindspore/mindspore/python/mindspore/log.py" "protected-access"
|
||||
"mindspore/mindspore/python/mindspore/rewrite/api/node.py" "protected-access"
|
||||
"mindspore/mindspore/python/mindspore/rewrite/node.py" "protected-access"
|
||||
"mindspore/mindspore/python/mindspore/rewrite/parser_register.py" "protected-access"
|
||||
"mindspore/mindspore/python/mindspore/rewrite/api/pattern_engine.py" "protected-access"
|
||||
"mindspore/mindspore/python/mindspore/rewrite/symbol_tree.py" "inconsistent-return-statements"
|
||||
"mindspore/model_zoo/official/cv" "missing-docstring"
|
||||
"mindspore/model_zoo/official/cv" "c-extension-no-member"
|
||||
"mindspore/model_zoo/official/nlp/bert_thor/src/bert_model.py" "redefined-outer-name"
|
||||
|
@ -101,6 +106,10 @@
|
|||
"mindspore/tests/ut/python/pynative_mode" "unused-variable"
|
||||
"mindspore/tests/ut/python/pynative_mode/test_stop_gradient.py" "redefined-outer-name"
|
||||
"mindspore/tests/ut/python/pynative_mode/test_stop_gradient.py" "super-init-not-called"
|
||||
"mindspore/tests/ut/python/rewrite/test_flatten_recursive_stmt.py" "consider-using-ternary"
|
||||
"mindspore/tests/ut/python/rewrite/test_node.py" "syntax-error"
|
||||
"mindspore/tests/ut/python/rewrite/test_node.py" "protected-access"
|
||||
"mindspore/tests/ut/python/rewrite/test_symbol_tree.py" "len-as-condition"
|
||||
"mindspore/tests/ut/python/test_log.py" "possibly-unused-variable"
|
||||
"mindspore/tests/ut/python/test_log.py" "protected-access"
|
||||
"mindspore/tests/ut/python/train/summary/test_summary_collector.py" "protected-access"
|
||||
|
|
|
@ -306,6 +306,7 @@ install(
|
|||
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/communication
|
||||
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/profiler
|
||||
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/compression
|
||||
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/rewrite
|
||||
${CMAKE_SOURCE_DIR}/mindspore/python/mindspore/run_check
|
||||
DESTINATION ${INSTALL_PY_DIR}
|
||||
COMPONENT mindspore
|
||||
|
|
|
@ -26,6 +26,7 @@ from .context import GRAPH_MODE, PYNATIVE_MODE, set_context, get_context, set_au
|
|||
get_auto_parallel_context, reset_auto_parallel_context, ParallelMode, set_ps_context, \
|
||||
get_ps_context, reset_ps_context, set_fl_context, get_fl_context
|
||||
from .version import __version__
|
||||
from .rewrite import *
|
||||
|
||||
|
||||
__all__ = ["run_check"]
|
||||
|
@ -34,3 +35,4 @@ __all__.extend(common.__all__)
|
|||
__all__.extend(train.__all__)
|
||||
__all__.extend(log.__all__)
|
||||
__all__.extend(context.__all__)
|
||||
__all__.extend(rewrite.__all__)
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
MindSpore Rewrite module.
|
||||
This is an experimental python package that is subject to change or deletion.
|
||||
"""
|
||||
from .parsers.module_parser import g_module_parser
|
||||
from .parsers.class_def_parser import g_classdef_parser
|
||||
from .parsers.function_def_parser import g_functiondef_parser
|
||||
from .parsers.arguments_parser import g_arguments_parser
|
||||
from .parsers.assign_parser import g_assign_parser
|
||||
from .parsers.return_parser import g_return_parser
|
||||
from .api.scoped_value import ScopedValue, ValueType
|
||||
from .api.symbol_tree import SymbolTree
|
||||
from .api.node import Node
|
||||
from .api.node_type import NodeType
|
||||
from .api.pattern_engine import PatternEngine, PatternNode, VarNode, Replacement
|
||||
|
||||
__all__ = ["SymbolTree", "Node", "NodeType", "ScopedValue", "ValueType", "PatternEngine", "PatternNode", "VarNode",
|
||||
"Replacement"]
|
|
@ -0,0 +1,17 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
MindSpore Rewrite api.
|
||||
"""
|
|
@ -0,0 +1,302 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Rewrite module api: Node."""
|
||||
|
||||
from typing import Union, Optional
|
||||
|
||||
from mindspore.nn import Cell
|
||||
from ..node import Node as NodeImpl
|
||||
from ..symbol_tree import SymbolTree as SymbolTreeImpl
|
||||
from .node_type import NodeType
|
||||
from .scoped_value import ScopedValue
|
||||
|
||||
|
||||
class Node:
|
||||
"""
|
||||
Node is a data structure represents a source code line in network.
|
||||
|
||||
For the most part, Node represents an operator invoking in forward which could be an instance of `Cell`, an instance
|
||||
of `Primitive` or a callable method.
|
||||
|
||||
`NodeImpl` mentioned below is implementation of `Node` which is not an interface of Rewrite. Rewrite recommend
|
||||
invoking specific create method of `Node` to instantiate an instance of Node such as `create_call_cell` rather than
|
||||
invoking constructor of `Node` directly, so don't care about what is `NodeImpl` and use its instance just as a
|
||||
handler.
|
||||
|
||||
Args:
|
||||
node (NodeImpl): A handler of `NodeImpl`.
|
||||
"""
|
||||
|
||||
def __init__(self, node: NodeImpl):
|
||||
self._node = node
|
||||
|
||||
def get_handler(self) -> NodeImpl:
|
||||
"""
|
||||
Get handler of node implementation.
|
||||
|
||||
Returns:
|
||||
An instance of `NodeImpl`.
|
||||
"""
|
||||
return self._node
|
||||
|
||||
@staticmethod
|
||||
def create_call_cell(cell: Cell, targets: [Union[ScopedValue, str]], args: [ScopedValue] = None,
|
||||
kwargs: {str: ScopedValue}=None, name: str = "") -> 'Node':
|
||||
"""
|
||||
Create a node. Only support create from a `Cell` now.
|
||||
|
||||
A node is corresponding to source code like:
|
||||
|
||||
.. code-block::
|
||||
|
||||
`targets` = self.`name`(*`args`, **`kwargs`)
|
||||
|
||||
Args:
|
||||
cell (Cell): Cell-operator of this forward-layer.
|
||||
targets (list[ScopedValue]): Indicate output names. Used as targets of an assign statement in source code.
|
||||
Rewrite will check and ensure the uniqueness of each target while node being inserted.
|
||||
args (list[ScopedValue]): Indicate input names. Used as args of a call expression of an assign statement in
|
||||
source code. Default is None indicate the `cell` has no args inputs. Rewrite will check and ensure the
|
||||
uniqueness of each arg while node being inserted.
|
||||
kwargs (dict{str: ScopedValue}): Indicate keyword input names. Used as kwargs of a call expression of an
|
||||
assign statement in source code. Default is None indicate the `cell` has no kwargs inputs. Rewrite will
|
||||
check and ensure the uniqueness of each kwarg while node being inserted.
|
||||
name (str): Indicate the name of node. Used as field name in source code. Default is None. Rewrite will
|
||||
generate name from `targets` when name is None. Rewrite will check and ensure the uniqueness of `name`
|
||||
while node being inserted.
|
||||
|
||||
Returns:
|
||||
An instance of `Node`.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If `cell` is not a `Cell`.
|
||||
RuntimeError: If `targets` is None.
|
||||
RuntimeError: If target in `targets` is not a `NamingValue`-`ScopedValue`.
|
||||
RuntimeError: If arg in `args` is not a `NamingValue`-`ScopedValue` or a `CustomObjValue`-`ScopedValue`.
|
||||
RuntimeError: If value of kwarg in `kwargs` is not a `NamingValue`-`ScopedValue` or a
|
||||
`CustomObjValue`-`ScopedValue`.
|
||||
"""
|
||||
return Node(NodeImpl.create_call_cell(cell, None, targets, ScopedValue.create_naming_value(name, "self"), args,
|
||||
kwargs, name))
|
||||
|
||||
def get_prev(self) -> 'Node':
|
||||
"""
|
||||
Get previous node of current node in source code order.
|
||||
|
||||
Returns:
|
||||
An instance of `Node` as previous node.
|
||||
"""
|
||||
return Node(self._node.get_prev())
|
||||
|
||||
def get_next(self) -> 'Node':
|
||||
"""
|
||||
Get next node of current node in source code order.
|
||||
|
||||
Returns:
|
||||
An instance of `Node` as next node.
|
||||
"""
|
||||
return Node(self._node.get_next())
|
||||
|
||||
def get_inputs(self) -> ['Node']:
|
||||
"""
|
||||
Get input nodes of current node in topological order.
|
||||
|
||||
Returns:
|
||||
A list of instances of `Node` as input nodes.
|
||||
"""
|
||||
return [Node(node_impl) for node_impl in self._node.get_inputs()]
|
||||
|
||||
def get_users(self) -> ['Node']:
|
||||
"""
|
||||
Get output nodes of current node in topological order.
|
||||
|
||||
Returns:
|
||||
A list of nodes represents users.
|
||||
"""
|
||||
belong_symbol_tree: SymbolTreeImpl = self._node.get_belong_symbol_tree()
|
||||
if belong_symbol_tree is None:
|
||||
return []
|
||||
unique_results = []
|
||||
for node_user in belong_symbol_tree.get_node_users(self._node):
|
||||
node = node_user[0]
|
||||
if node not in unique_results:
|
||||
unique_results.append(node)
|
||||
return [Node(node_impl) for node_impl in unique_results]
|
||||
|
||||
def set_arg(self, index: int, arg: Union[ScopedValue, str]):
|
||||
"""
|
||||
Set argument of current node.
|
||||
|
||||
Args:
|
||||
index (int): Indicate which input being modified.
|
||||
arg (Union[ScopedValue, str]): New argument to been set.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If `index` is out of range.
|
||||
RuntimeError: If `arg` a `NamingValue`-`ScopedValue` or a `CustomObjValue`-`ScopedValue` when `arg` is an
|
||||
`ScopedValue`.
|
||||
"""
|
||||
belong_symbol_tree: SymbolTreeImpl = self._node.get_belong_symbol_tree()
|
||||
if belong_symbol_tree is None:
|
||||
self._node.set_arg(arg, index)
|
||||
else:
|
||||
belong_symbol_tree.set_node_arg(self._node, index, arg)
|
||||
|
||||
def set_arg_by_node(self, arg_idx: int, src_node: 'Node', out_idx: Optional[int] = None):
|
||||
"""
|
||||
Set argument of current node by another `Node`.
|
||||
|
||||
Args:
|
||||
arg_idx (int): Indicate which input being modified.
|
||||
src_node (Node): A `Node` as new input. Can be a node or name of node.
|
||||
out_idx (int, optional): Indicate which output of `src_node` as new input of current node. Default is None
|
||||
which means use first output of `src_node` as new input.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If `src_node` is not belong to current `SymbolTree`.
|
||||
RuntimeError: If current node and `src_node` is not belong to same `SymbolTree`.
|
||||
RuntimeError: If `arg_idx` is out of range.
|
||||
RuntimeError: If `out_idx` is out of range.
|
||||
RuntimeError: If `src_node` has multi-outputs while `out_idx` is None or `out_idx` is not offered.
|
||||
"""
|
||||
belong_symbol_tree: SymbolTreeImpl = self._node.get_belong_symbol_tree()
|
||||
if belong_symbol_tree is None:
|
||||
self._node.set_arg_by_node(arg_idx, src_node._node, out_idx)
|
||||
else:
|
||||
belong_symbol_tree.set_node_arg_by_node(self._node, arg_idx, src_node._node, out_idx)
|
||||
|
||||
def get_targets(self) -> [ScopedValue]:
|
||||
"""
|
||||
Get targets of current node.
|
||||
|
||||
- When node_type of current node is `CallCell`, `CallPrimitive`, `CallMethod` or `Tree`, `targets` are strings
|
||||
represents invoke result of the cell-op or primitive-op or function-call which are corresponding to targets of
|
||||
ast.Assign.
|
||||
- When node_type of current node is Input, `targets` should have only one element which is a string represents
|
||||
parameter of function.
|
||||
- When node_type of current node is `Python` or `Output`, `targets` are don't-care.
|
||||
|
||||
Returns:
|
||||
A list of instances of ScopedValue as targets of node.
|
||||
"""
|
||||
return self._node.get_targets()
|
||||
|
||||
def get_name(self) -> str:
|
||||
"""
|
||||
Get the name of current node.
|
||||
|
||||
When node has been inserted into `SymbolTree`, the name of node should be unique in `SymbolTree`.
|
||||
|
||||
Returns:
|
||||
A string as name of node.
|
||||
"""
|
||||
return self._node.get_name()
|
||||
|
||||
def get_node_type(self) -> NodeType:
|
||||
"""
|
||||
Get the node_type of current node.
|
||||
|
||||
Returns:
|
||||
A NodeType as node_type of node.
|
||||
"""
|
||||
return self._node.get_node_type()
|
||||
|
||||
def get_instance_type(self) -> type:
|
||||
"""
|
||||
Get the instance_type of current node.
|
||||
|
||||
- When node_type of current node is `CallCell`, instance_type is type of cell-op.
|
||||
- When node_type of current node is `CallPrimitive`, instance_type is type of primitive-op.
|
||||
- When node_type of current node is `Tree`, instance_type is type of network-cell.
|
||||
- When node_type of current node is `Python`, `Input`, `Output` or `CallMethod`, instance_type should be
|
||||
NoneType.
|
||||
|
||||
Returns:
|
||||
A type object represents corresponding instance type of current node.
|
||||
"""
|
||||
return self._node.get_instance_type()
|
||||
|
||||
def get_instance(self):
|
||||
"""
|
||||
Get the instance of current node.
|
||||
|
||||
- When node_type of current node is `CallCell`, instance is an instance of Cell.
|
||||
- When node_type of current node is `CallPrimitive`, instance is an instance of primitive.
|
||||
- When node_type of current node is `Tree`, instance is an instance of network-cell.
|
||||
- When node_type of current node is `Python`, `Input`, `Output` or `CallMethod`, instance should be None.
|
||||
|
||||
Returns:
|
||||
A object represents corresponding instance of current node.
|
||||
"""
|
||||
return self._node.get_instance()
|
||||
|
||||
def get_args(self) -> [ScopedValue]:
|
||||
"""
|
||||
Get the arguments of current node.
|
||||
|
||||
- When node_type of current node is `CallCell`, `CallPrimitive` or `Tree`, arguments are corresponding to args
|
||||
of ast.Call which represents arguments to invoke forward method of cell-op or primitive-op.
|
||||
- When node_type of current node is `Input`, arguments represents default-value of argument of function.
|
||||
- When node_type of current node is `Output`, arguments represents return values.
|
||||
- When node_type of current node is `Python`, arguments are don't-care.
|
||||
|
||||
Returns:
|
||||
A list of instances of `ScopedValue`.
|
||||
"""
|
||||
return self._node.get_args()
|
||||
|
||||
def get_kwargs(self) -> {str: ScopedValue}:
|
||||
"""
|
||||
Get the keyword arguments of current node.
|
||||
|
||||
- When node_type of current node is `CallCell`, `CallPrimitive` or `Tree`, keyword arguments are corresponding
|
||||
to kwargs of ast.Call which represents arguments to invoke forward method of cell-op or primitive-op.
|
||||
- When node_type of current node is `Python`, `Input` or `Output`, keyword arguments are don't-care.
|
||||
|
||||
Returns:
|
||||
A dict of str to instance of `ScopedValue`.
|
||||
"""
|
||||
return self._node.get_kwargs()
|
||||
|
||||
def set_attribute(self, key: str, value):
|
||||
"""
|
||||
Set attribute of current node.
|
||||
|
||||
Args:
|
||||
key (str): Key of attribute.
|
||||
value (object): Value of attribute.
|
||||
"""
|
||||
self._node.set_attribute(key, value)
|
||||
|
||||
def get_attributes(self) -> {str: object}:
|
||||
"""
|
||||
Get all attributes of current node.
|
||||
|
||||
Returns:
|
||||
A dict of str to instance of object as attributes.
|
||||
"""
|
||||
return self._node.get_attributes()
|
||||
|
||||
def get_attribute(self, key: str):
|
||||
"""
|
||||
Get attribute of current node by key.
|
||||
|
||||
Returns:
|
||||
A object as attribute.
|
||||
"""
|
||||
return self._node.get_attribute(key)
|
||||
|
||||
def __eq__(self, other: 'Node'):
|
||||
return self._node == other._node
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Rewrite module api: NodeType."""
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class NodeType(Enum):
|
||||
"""
|
||||
`NodeType` represents type of `Node`.
|
||||
|
||||
- Unknown: Not inited NodeType.
|
||||
- CallCell: `CallCell` node represents invoking cell-op in forward method.
|
||||
- CallPrimitive: `CallPrimitive` node represents invoking primitive-op in forward method.
|
||||
- CallMethod: `CallMethod` node represents invoking of method in forward method which can not be mapped to
|
||||
cell-op or primitive-op in MindSpore.
|
||||
- Python: `Python` node holds unsupported-ast-node or unnecessary-to-parse-ast-node.
|
||||
- Input: `Input` node represents input of `SymbolTree` corresponding to arguments of forward method.
|
||||
- Output: `Output` node represents output of SymbolTree corresponding to return statement of forward method.
|
||||
- Tree: `Tree` node represents sub-network invoking in forward method.
|
||||
|
||||
"""
|
||||
Unknown = 0
|
||||
# Compute node type
|
||||
CallCell = 1
|
||||
CallPrimitive = 2
|
||||
CallMethod = 3
|
||||
# Other node type
|
||||
Python = 4
|
||||
Input = 5
|
||||
Output = 6
|
||||
Tree = 7
|
|
@ -0,0 +1,413 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""PatternEngine for modifying SymbolTree by pattern."""
|
||||
from collections import OrderedDict
|
||||
from typing import Tuple, Union, List, Type
|
||||
import abc
|
||||
|
||||
from mindspore.nn import Cell
|
||||
from mindspore import log as logger
|
||||
from .node_type import NodeType
|
||||
from .node import Node
|
||||
from .symbol_tree import SymbolTree
|
||||
|
||||
|
||||
class PatternNode:
|
||||
"""
|
||||
`PatternNode` is defined as a node while defining pattern.
|
||||
|
||||
Args:
|
||||
pattern_node_name (str): Name of current node.
|
||||
match_type (Type): A type represents what type would be matched of current node.
|
||||
inputs (list[PatternNode]): Input nodes of current node.
|
||||
"""
|
||||
|
||||
def __init__(self, pattern_node_name: str, match_type: Type = Type[None], inputs: ['PatternNode'] = None):
|
||||
self._name = pattern_node_name
|
||||
self._type = match_type
|
||||
if inputs is None:
|
||||
self._inputs = []
|
||||
else:
|
||||
self._inputs = inputs
|
||||
|
||||
@staticmethod
|
||||
def from_node(node: Node) -> 'PatternNode':
|
||||
"""
|
||||
Create a `PatternNode` from `node`.
|
||||
|
||||
Args:
|
||||
node (Node): Input rewrite node.
|
||||
|
||||
Returns:
|
||||
A `PatternNode` created from `node`.
|
||||
"""
|
||||
|
||||
pattern_node: PatternNode = PatternNode(node.get_targets()[0])
|
||||
if node.get_node_type() is NodeType.CallCell:
|
||||
pattern_node._type = node.get_instance_type()
|
||||
return pattern_node
|
||||
|
||||
@staticmethod
|
||||
def create_pattern_from_node(node: Node) -> 'PatternNode':
|
||||
"""
|
||||
Create a Pattern from `node` with its inputs.
|
||||
|
||||
Args:
|
||||
node (Node): Input rewrite node.
|
||||
|
||||
Returns:
|
||||
A `PatternNode` as root of pattern created from rewrite node.
|
||||
"""
|
||||
|
||||
pattern_node: PatternNode = PatternNode.from_node(node)
|
||||
inputs = []
|
||||
for node_input in node.get_inputs():
|
||||
inputs.append(PatternNode.create_pattern_from_node(node_input))
|
||||
pattern_node._inputs = inputs
|
||||
return pattern_node
|
||||
|
||||
@staticmethod
|
||||
def create_pattern_from_list(type_list: []) -> 'PatternNode':
|
||||
"""
|
||||
Create a Pattern from a cell type list.
|
||||
|
||||
Args:
|
||||
type_list (list[type]): Input cell type list.
|
||||
|
||||
Returns:
|
||||
A `PatternNode` as root of pattern created from cell type list.
|
||||
"""
|
||||
|
||||
last_node = None
|
||||
for i, cell_type in enumerate(type_list):
|
||||
cur_node: PatternNode = PatternNode(str(i) + "-" + str(cell_type), cell_type, [])
|
||||
if last_node is not None:
|
||||
cur_node._inputs = [last_node]
|
||||
else:
|
||||
cur_node._inputs = []
|
||||
last_node = cur_node
|
||||
return last_node
|
||||
|
||||
def add_input(self, node):
|
||||
"""
|
||||
Add an input for current `PatternNode`.
|
||||
|
||||
Args:
|
||||
node (PatternNode): Cell type as an input.
|
||||
"""
|
||||
|
||||
self._inputs.append(node)
|
||||
|
||||
def set_inputs(self, inputs):
|
||||
"""
|
||||
Set inputs for current `PatternNode`.
|
||||
|
||||
Args:
|
||||
inputs (list[PatternNode]) : Inputs to be set as inputs of current `PatternNode`.
|
||||
"""
|
||||
|
||||
self._inputs = inputs
|
||||
|
||||
def match(self, node: Node) -> bool:
|
||||
"""
|
||||
Check if current `PatternNode` can match with `node`.
|
||||
|
||||
Args:
|
||||
node (Node) : A rewrite node to be match.
|
||||
"""
|
||||
|
||||
return self._type == node.get_instance_type()
|
||||
|
||||
def get_inputs(self):
|
||||
"""
|
||||
Getter of inputs.
|
||||
"""
|
||||
|
||||
return self._inputs
|
||||
|
||||
def name(self) -> str:
|
||||
"""
|
||||
Getter of name.
|
||||
"""
|
||||
return self._name
|
||||
|
||||
def type(self):
|
||||
"""
|
||||
Getter of type.
|
||||
"""
|
||||
return self._type
|
||||
|
||||
|
||||
class VarNode(PatternNode):
|
||||
"""
|
||||
VarNode is a subclass of `PatternNode` whose `match` method is always return True.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(VarNode, self).__init__("placeholder", Cell, [])
|
||||
|
||||
def match(self, node: Node) -> bool:
|
||||
return node is not None and node.get_handler() is not None
|
||||
|
||||
|
||||
class Replacement(abc.ABC):
|
||||
"""
|
||||
Interface of replacement function.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
|
||||
"""
|
||||
Interface define for creating replacement nodes from matched result.
|
||||
|
||||
Note:
|
||||
Return value will be delivered into replace api of `SymbolTree` as argument, return value should follow
|
||||
restraint of parameter `new_nodes` of `replace` api if `SymbolTree`. See detail in docstring of `replace`
|
||||
api of `SymbolTree`.
|
||||
|
||||
Args:
|
||||
pattern (PatternNode): A `PatternNode` represents root node of current pattern.
|
||||
is_chain_pattern (bool): A bool indicated if pattern is a chain pattern or a tree pattern.
|
||||
matched (OrderedDict): An OrderedDict map from pattern_node name to node represents matched result.
|
||||
|
||||
Returns:
|
||||
A list of instance of `Node` as replacement nodes.
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
|
||||
return self.build(pattern, is_chain_pattern, matched)
|
||||
|
||||
|
||||
class PatternEngine:
|
||||
"""
|
||||
`PatternEngine` is defined how to transform a `SymbolTree` by `PattenNode`.
|
||||
|
||||
Args:
|
||||
pattern (Union[PatternNode, List]): An instance of `PatternNode` or a cell-type-list to construct `PatternNode`
|
||||
as root of a pattern.
|
||||
replacement (callable): A callable define how to generate new_node.
|
||||
"""
|
||||
|
||||
def __init__(self, pattern: Union[PatternNode, List], replacement: Replacement = None):
|
||||
if isinstance(pattern, PatternNode):
|
||||
self._is_chain = False
|
||||
self._replacement: Replacement = replacement
|
||||
self._pattern: PatternNode = pattern
|
||||
elif isinstance(pattern, list):
|
||||
self._is_chain = True
|
||||
self._replacement: Replacement = replacement
|
||||
self._pattern: PatternNode = PatternNode.create_pattern_from_list(pattern)
|
||||
else:
|
||||
raise RuntimeError("Unsupported pattern define")
|
||||
|
||||
def pattern(self) -> PatternNode:
|
||||
"""
|
||||
Getter of pattern.
|
||||
"""
|
||||
|
||||
return self._pattern
|
||||
|
||||
@staticmethod
|
||||
def _multi_to_multi_replace(stree: SymbolTree, old_root: Node, matched_dict: OrderedDict,
|
||||
new_nodes: [Node]) -> Node:
|
||||
"""
|
||||
Replace multi-nodes in `stree` by another list of nodes.
|
||||
|
||||
Note:
|
||||
Call replace api of `SymbolTree`, so parameter `new_nodes` has same restraint with parameter `new_nodes` of
|
||||
`replace` api if `SymbolTree`. See detail in docstring of `replace` api of `SymbolTree`.
|
||||
|
||||
Args:
|
||||
stree (SymbolTree): A `SymbolTree` which replacement will apply on.
|
||||
old_root (Node): A `Node` represents root of original nodes.
|
||||
matched_dict (OrderedDict): An instance of OrderedDict as match result, where key is the pattern name, value
|
||||
is the matched node.
|
||||
new_nodes (list[Node]): A list of instance of Node as replacement.
|
||||
"""
|
||||
to_erase_list = matched_dict.values()
|
||||
# keep all old nodes' inputs
|
||||
inputs_dict = {}
|
||||
for node in to_erase_list:
|
||||
inputs_dict[node.get_name()] = (node.get_inputs())
|
||||
# call replace of SymbolTree
|
||||
new_root = stree.replace(old_root, new_nodes)
|
||||
# replace only support one-to-one replace or one-to-multi replace, we need to erase nodes except
|
||||
# cur_node manually
|
||||
queue: [Node] = [old_root]
|
||||
while queue:
|
||||
cur_node: Node = queue.pop(0)
|
||||
if cur_node in to_erase_list:
|
||||
if cur_node.get_users():
|
||||
# if cur_node is depended on by other node, skip now.
|
||||
# cur_node will be push into queue and be erased later
|
||||
continue
|
||||
if stree.get_node(cur_node.get_name()) is not None:
|
||||
# cur_node is not erased before
|
||||
stree.erase_node(cur_node)
|
||||
queue.extend(inputs_dict.get(cur_node.get_name()))
|
||||
return new_root
|
||||
|
||||
def apply(self, stree: SymbolTree) -> bool:
|
||||
"""
|
||||
Apply current pattern to a `SymbolTree`.
|
||||
|
||||
Note:
|
||||
Sub-tree node will be supported in the near feature.
|
||||
|
||||
Args:
|
||||
stree (SymbolTree): A `SymbolTree` to be transformed.
|
||||
|
||||
Returns:
|
||||
A bool represents if `stree` been changed.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If `SymbolTree` has no return node.
|
||||
"""
|
||||
|
||||
root: Node = stree.get_return_node()
|
||||
if root is None:
|
||||
raise RuntimeError("SymbolTree should be inited and has return node")
|
||||
changed = False
|
||||
# IR match
|
||||
queue: [Node] = [root]
|
||||
# Why need visited: we don't need or should not to visit same node multi-times because pre-visited node may
|
||||
# already been erased from SymbolTree.
|
||||
# When will we visit same node multi-times:
|
||||
# a
|
||||
# / \
|
||||
# / \
|
||||
# b c
|
||||
# | |
|
||||
# | d
|
||||
# \ /
|
||||
# \ /
|
||||
# e
|
||||
# 1. Visit e, e does not match pattern, add b, d to queue.
|
||||
# 2. Visit b, b does not match pattern, add a to queue.
|
||||
# 3. Visit d, d does not match pattern, add c to queue.
|
||||
# 4. Visit a, a matches pattern and erased from SymbolTree, add xx to queue.
|
||||
# 5. Visit c, d does not match pattern, add a to queue.
|
||||
# At step 5, a is visited at second time but a is erased from SymbolTree at step 4.
|
||||
visited: [Node] = []
|
||||
while queue:
|
||||
cur_node: Node = queue.pop(0)
|
||||
if cur_node is None: # Because inputs of node is allowed to be None in replacement.
|
||||
continue
|
||||
if cur_node in visited:
|
||||
continue
|
||||
visited.append(cur_node)
|
||||
matched, matched_dict = self._match(self._pattern, cur_node)
|
||||
# not matched
|
||||
if not matched or not PatternEngine._check_match(self._pattern, matched_dict):
|
||||
if cur_node is not None:
|
||||
queue.extend(cur_node.get_inputs())
|
||||
continue
|
||||
# matched
|
||||
new_nodes: [Node] = []
|
||||
if self._replacement is not None:
|
||||
new_nodes: [Node] = self._replacement(self._pattern, self._is_chain, matched_dict)
|
||||
if not new_nodes: # if replacement is empty, do nothing
|
||||
queue.extend(cur_node.get_inputs())
|
||||
else: # replace cur_node with new_nodes
|
||||
changed = True
|
||||
root = PatternEngine._multi_to_multi_replace(stree, cur_node, matched_dict, new_nodes)
|
||||
queue.append(root)
|
||||
return changed
|
||||
|
||||
@staticmethod
|
||||
def _merge_ordered_dict(dict1: OrderedDict, dict2: OrderedDict) -> OrderedDict:
|
||||
"""
|
||||
A static util method to merge two OrderedDict.
|
||||
|
||||
Args:
|
||||
dict1 (OrderedDict): First dict to be merged.
|
||||
dict2 (OrderedDict): Second dict to be merged.
|
||||
|
||||
Returns:
|
||||
Merged OrderedDict.
|
||||
"""
|
||||
|
||||
merged = dict1.copy()
|
||||
merged.update(dict2)
|
||||
return merged
|
||||
|
||||
def _match(self, pattern: PatternNode, node: Node) -> Tuple[bool, OrderedDict]:
|
||||
"""
|
||||
Match `pattern` with a `node` with all inputs of the `pattern`.
|
||||
|
||||
Args:
|
||||
pattern (PatternNode): Pattern to be match.
|
||||
node (Node): Node to be match.
|
||||
|
||||
Returns:
|
||||
A bool value to indicate if matched.
|
||||
An instance of OrderedDict as match result, where key is the pattern name, value is the matched node.
|
||||
"""
|
||||
|
||||
# Don't iterate into subgraph node, pattern should not be matched across sub-tree
|
||||
if node.get_node_type() != NodeType.CallCell and node.get_node_type() != NodeType.Input:
|
||||
logger.debug("Pattern match failed: node(%s) is not a cell", str(node))
|
||||
return False, OrderedDict()
|
||||
if not pattern.match(node):
|
||||
logger.debug("Pattern match failed: node(%s)'s type is %s while pattern type is %s", str(node),
|
||||
node.get_instance_type(), pattern.type())
|
||||
return False, OrderedDict()
|
||||
if isinstance(pattern, VarNode):
|
||||
return True, OrderedDict()
|
||||
pattern_inputs = pattern.get_inputs()
|
||||
cur_inputs = node.get_inputs()
|
||||
input_num = len(pattern_inputs)
|
||||
if input_num == 0:
|
||||
return True, OrderedDict({pattern.name(): node})
|
||||
if input_num != len(cur_inputs):
|
||||
logger.debug("Pattern match failed: node(%s)'s has %d inputs while pattern has %d inputs", str(node),
|
||||
len(node.get_inputs()), input_num)
|
||||
return False, OrderedDict()
|
||||
result = OrderedDict()
|
||||
for i in range(0, input_num):
|
||||
is_matched, tmp_result = self._match(pattern_inputs[i], cur_inputs[i])
|
||||
if not is_matched:
|
||||
return False, OrderedDict()
|
||||
result = PatternEngine._merge_ordered_dict(result, tmp_result)
|
||||
result[pattern.name()] = node
|
||||
return True, result
|
||||
|
||||
@staticmethod
|
||||
def _check_match(pattern: PatternNode, match_dict: OrderedDict) -> bool:
|
||||
"""
|
||||
Check if matched result is a leak result.
|
||||
|
||||
A leak result means that the result is matched the `pattern`, but some nodes in result which is not
|
||||
corresponding to root of pattern have outputs used by nodes outside of result.
|
||||
|
||||
Args:
|
||||
pattern (PatternNode): A `PatternNode` represents pattern to be match.
|
||||
match_dict (OrderedDict): A OrderedDict represents matched result.
|
||||
|
||||
Returns:
|
||||
A bool value to indicate if matched result leaked.
|
||||
"""
|
||||
matched_nodes = match_dict.values()
|
||||
for key in match_dict:
|
||||
if key == pattern.name():
|
||||
continue
|
||||
node: Node = match_dict[key]
|
||||
for output in node.get_users():
|
||||
if output not in matched_nodes:
|
||||
logger.debug("Check match failed, pattern leaked")
|
||||
return False
|
||||
return True
|
|
@ -0,0 +1,152 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Rewrite module api: ValueType and ScopedValue."""
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class ValueType(Enum):
|
||||
"""
|
||||
ValueType represents type of `ScopedValue`.
|
||||
|
||||
- A `NamingValue` represents a reference to another variable.
|
||||
- A `CustomObjValue` represents an instance of custom class or an object whose type is out of range of base-type
|
||||
and container-type of ValueType.
|
||||
"""
|
||||
|
||||
# base type
|
||||
StringValue = 0
|
||||
IntValue = 1
|
||||
FloatValue = 2
|
||||
# container type
|
||||
TupleValue = 20
|
||||
ListValue = 21
|
||||
DictValue = 22
|
||||
# other type
|
||||
NamingValue = 40
|
||||
CustomObjValue = 41
|
||||
|
||||
|
||||
class ScopedValue:
|
||||
"""
|
||||
`ScopedValue` represents a value with its full-scope.
|
||||
|
||||
`ScopedValue` is used to express: a left-value such as target of an assign statement, or a callable object such as
|
||||
func of a call statement, or a right-value such as args and kwargs of an assign statement.
|
||||
|
||||
Args:
|
||||
arg_type (ValueType): A `ValueType` represents type of current value.
|
||||
scope (str): A string represents scope of current value. Take "self.var1" as an example, `scope` of this
|
||||
var1 is "self".
|
||||
value: A handler represents value of current value. The type of value is corresponding to `arg_type`.
|
||||
"""
|
||||
|
||||
def __init__(self, arg_type: ValueType, scope: str = "", value=None):
|
||||
self.type = arg_type
|
||||
self.scope = scope
|
||||
self.value = value
|
||||
|
||||
@classmethod
|
||||
def create_variable_value(cls, value) -> Optional['ScopedValue']:
|
||||
"""
|
||||
Create `ScopedValue` from a variable.
|
||||
|
||||
`ScopedValue`'s type is determined by type of value. `ScopedValue`'s scope is empty.
|
||||
|
||||
Args:
|
||||
value: The value to be converted to `ScopedValue`.
|
||||
|
||||
Returns:
|
||||
An instance of `ScopedValue`.
|
||||
"""
|
||||
if isinstance(value, int):
|
||||
return cls(ValueType.IntValue, "", value)
|
||||
if isinstance(value, float):
|
||||
return cls(ValueType.FloatValue, "", value)
|
||||
if isinstance(value, str):
|
||||
return cls(ValueType.StringValue, "", value)
|
||||
if isinstance(value, tuple):
|
||||
return cls(ValueType.TupleValue, "",
|
||||
tuple(cls.create_variable_value(single_value) for single_value in value))
|
||||
if isinstance(value, list):
|
||||
return cls(ValueType.ListValue, "", list(cls.create_variable_value(single_value) for single_value in value))
|
||||
if isinstance(value, dict):
|
||||
for key, _ in value.items():
|
||||
if not isinstance(key, str):
|
||||
raise TypeError("key should be str, got: ", type(key))
|
||||
return cls(ValueType.DictValue, "",
|
||||
dict((key, cls.create_variable_value(single_value)) for key, single_value in value.items()))
|
||||
return cls(ValueType.CustomObjValue, "", value)
|
||||
|
||||
@classmethod
|
||||
def create_naming_value(cls, name: str, scope: str = "") -> 'ScopedValue':
|
||||
"""
|
||||
Create a naming `ScopedValue`. A `NamingValue` represents a reference to another variable.
|
||||
|
||||
Args:
|
||||
name: (str): A string represents the identifier of another variable.
|
||||
scope: (str): A string represents the scope of another variable.
|
||||
|
||||
Returns:
|
||||
An instance of `ScopedValue`.
|
||||
"""
|
||||
return cls(ValueType.NamingValue, scope, name)
|
||||
|
||||
@staticmethod
|
||||
def create_name_values(names: [str], scopes: [str] = None) -> ['ScopedValue']:
|
||||
"""
|
||||
Create a list of naming `ScopedValue`.
|
||||
|
||||
Args:
|
||||
names: (list[str]): A list of string represents names of referenced variables.
|
||||
scopes: (list[str]): A list of string represents scopes of referenced variables.
|
||||
|
||||
Returns:
|
||||
An list of instance of `ScopedValue`.
|
||||
|
||||
Raise:
|
||||
RuntimeError: If the length of names is not equal to the length of scopes when scopes are not None.
|
||||
"""
|
||||
if scopes is not None:
|
||||
if len(names) != len(scopes):
|
||||
raise RuntimeError("Length of names should be equal to length of scopes")
|
||||
result = []
|
||||
for index, name in enumerate(names):
|
||||
if scopes is not None:
|
||||
scope = scopes[index]
|
||||
else:
|
||||
scope = ""
|
||||
result.append(ScopedValue.create_naming_value(name, scope))
|
||||
return result
|
||||
|
||||
def __str__(self):
|
||||
if self.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
|
||||
return str(self.value)
|
||||
if self.type == ValueType.NamingValue:
|
||||
return f"{self.scope}.{self.value}" if self.scope else str(self.value)
|
||||
if self.type == ValueType.CustomObjValue:
|
||||
return f"CustomObj: {str(self.value)}"
|
||||
return f"Illegal ValueType: {str(self.type)}"
|
||||
|
||||
def __eq__(self, other):
|
||||
if id(self) == id(other):
|
||||
return True
|
||||
return self.type == other.type and self.scope == other.scope and self.value == other.value
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.type, self.scope, self.value))
|
|
@ -0,0 +1,227 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Rewrite module api: SymbolTree."""
|
||||
from typing import Optional
|
||||
|
||||
from mindspore.nn import Cell
|
||||
from .node import Node
|
||||
from ..symbol_tree_builder import SymbolTreeBuilder
|
||||
from ..symbol_tree import SymbolTree as SymbolTreeImpl
|
||||
|
||||
|
||||
class SymbolTree:
|
||||
"""
|
||||
A `SymbolTree` usually corresponding to forward method of a network.
|
||||
|
||||
Args:
|
||||
network (Cell): Network to be rewritten. Only support `Cell`-type-network now.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If `network` is not a Cell.
|
||||
RuntimeError: If there is any unsupported ast node type while parsing or optimizing.
|
||||
"""
|
||||
|
||||
def __init__(self, network: Cell):
|
||||
self._symbol_tree: SymbolTreeImpl = SymbolTreeBuilder(network).build()
|
||||
|
||||
def get_handler(self) -> SymbolTreeImpl:
|
||||
"""
|
||||
Get handler of `SymbolTree` implementation.
|
||||
|
||||
Returns:
|
||||
An instance of `SymbolTree`.
|
||||
"""
|
||||
return self._symbol_tree
|
||||
|
||||
def nodes(self) -> {}:
|
||||
"""
|
||||
Get all nodes of corresponding network.
|
||||
|
||||
Returns:
|
||||
A dict mapping from name of node to node.
|
||||
"""
|
||||
return [Node(node_impl) for node_impl in self._symbol_tree.nodes(unfold_subtree=True)]
|
||||
|
||||
def get_node(self, node_name: str) -> Optional[Node]:
|
||||
"""
|
||||
Get node by `node_name`.
|
||||
|
||||
Args:
|
||||
node_name (str): A string represents name of node.
|
||||
|
||||
Returns:
|
||||
An instance of node if find else None.
|
||||
"""
|
||||
node_impl = self._symbol_tree.get_node(node_name)
|
||||
if node_impl is None:
|
||||
return None
|
||||
return Node(node_impl)
|
||||
|
||||
def get_return_node(self) -> Node:
|
||||
"""
|
||||
Get return node of current `SymbolTree`.
|
||||
|
||||
A `SymbolTree` can and should have one return node corresponding to return statement of forward method of
|
||||
network.
|
||||
|
||||
Returns:
|
||||
An instance of node represents return node.
|
||||
"""
|
||||
ret = self._symbol_tree.get_return_node()
|
||||
if ret is None:
|
||||
raise RuntimeError("SymbolTree is not well inited, can not find return node.")
|
||||
return Node(ret)
|
||||
|
||||
def before(self, node: Node):
|
||||
"""
|
||||
Get insert position before input `node`.
|
||||
|
||||
`Position` is used to indicate where to insert node, it indicates position in source code rather than position
|
||||
in topological order. We don't need to care about what `Position` is, just treat it as a handler and use it as
|
||||
an arguments of `insert` api of `SymbolTree`.
|
||||
|
||||
Args:
|
||||
node (Node): Indicate the position before which node. Can be a node or name of node.
|
||||
|
||||
Returns:
|
||||
A `Position` to indicate where to insert node.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if `node` is not a Node or a string.
|
||||
"""
|
||||
return self._symbol_tree.before(node.get_handler())
|
||||
|
||||
def after(self, node: Node):
|
||||
"""
|
||||
Get insert position after input `node`.
|
||||
|
||||
`Position` is used to indicate where to insert node, it indicates position in source code rather than position
|
||||
in topological order. We don't need to care about what `Position` is, just treat it as a handler and use it as
|
||||
an arguments of `insert` api of `SymbolTree`.
|
||||
|
||||
Args:
|
||||
node (Node): Indicate the position after which node. Can be a node or name of node.
|
||||
|
||||
Returns:
|
||||
A `Position` to indicate where to insert node.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If `node` is not a Node.
|
||||
"""
|
||||
return self._symbol_tree.after(node.get_handler())
|
||||
|
||||
def insert(self, position, node: Node) -> Node:
|
||||
"""
|
||||
Insert a `node` into `SymbolTree` at `position`.
|
||||
|
||||
`position` is obtained from `before` api or `after` api of `SymbolTree`.
|
||||
|
||||
Args:
|
||||
position (Position): Indicate where to insert `node`.
|
||||
node (Node): An instance of Node to be inserted.
|
||||
|
||||
Returns:
|
||||
An instance of Node being inserted. `node` could be changed while calling this method for uniqueness and
|
||||
custom-object in args or kwargs.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If `position` is not belong to current `SymbolTree`.
|
||||
"""
|
||||
return Node(self._symbol_tree.insert_node(position, node.get_handler()))
|
||||
|
||||
def erase_node(self, node: Node) -> Optional[Node]:
|
||||
"""
|
||||
Erase a `node` from rewrite. Can only erase a node not being depended on.
|
||||
|
||||
Args:
|
||||
node (Node): A `Node` to be erased. Can be a node or name of node.
|
||||
|
||||
Returns:
|
||||
An instance of `Node` being erased if node is in `SymbolTree` else None.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If `node` is not a `Node`.
|
||||
"""
|
||||
return Node(self._symbol_tree.erase_node(node.get_handler()))
|
||||
|
||||
def replace(self, old_node: Node, new_nodes: [Node]) -> Node:
|
||||
"""
|
||||
Replace `old_node` with a node_tree.
|
||||
|
||||
Note:
|
||||
|
||||
1. Replace support one-to-one replacement or one-to-multi replacement. If you need multi-to-multi
|
||||
replacement, please refer to `PatternEngine`.
|
||||
2. When applying one-to-multi replacement, Rewrite will insert all `new_nodes` into symbol_tree.
|
||||
3. Caller should maintain arguments and targets of nodes intra sub-tree for specifying topological relation
|
||||
intra sub-tree.
|
||||
4. Caller should maintain arguments of input nodes of sub-tree and for specifying topological relation of
|
||||
inputs of sub-tree.
|
||||
5. Rewrite will maintain arguments of prepend node of sub-tree for specifying topological relation of
|
||||
outputs of sub-tree.
|
||||
6. Rewrite will maintain all inputs of nodes after replace `new_nodes` into `SymbolTree`.
|
||||
|
||||
Args:
|
||||
old_node (Node): Node to be replaced.
|
||||
new_nodes (list[Node]): Nodes of the node_tree to replace in.
|
||||
|
||||
Returns:
|
||||
An instance of Node represents root of node_tree been replaced in.
|
||||
|
||||
Raises:
|
||||
RuntimeError: Old node is isolated.
|
||||
"""
|
||||
nodes_impl = [node.get_handler() for node in new_nodes]
|
||||
return Node(self._symbol_tree.replace(old_node.get_handler(), nodes_impl))
|
||||
|
||||
def set_output(self, index: int, return_value: str) -> Node:
|
||||
"""
|
||||
Set return value of network.
|
||||
|
||||
Args:
|
||||
index (int): Indicate which output being modified.
|
||||
return_value (str): New return value to been set.
|
||||
|
||||
Returns:
|
||||
Return node of current rewrite.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If `index` is out of range.
|
||||
"""
|
||||
return Node(self._symbol_tree.set_output(return_value, index))
|
||||
|
||||
def dump(self):
|
||||
"""
|
||||
Dump graph to console.
|
||||
"""
|
||||
self._symbol_tree.dump()
|
||||
|
||||
def get_code(self) -> str:
|
||||
"""
|
||||
Get source code of modified network.
|
||||
|
||||
Returns:
|
||||
A str represents source code of modified network.
|
||||
"""
|
||||
return self._symbol_tree.get_code()
|
||||
|
||||
def get_network(self) -> Cell:
|
||||
"""
|
||||
Get modified network.
|
||||
|
||||
Returns:
|
||||
A network object.
|
||||
"""
|
||||
return self._symbol_tree.get_network()
|
|
@ -0,0 +1,292 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Ast utils for create or update ast node."""
|
||||
from typing import Optional
|
||||
import ast
|
||||
|
||||
from .api.scoped_value import ScopedValue, ValueType
|
||||
|
||||
|
||||
class AstModifier(ast.NodeTransformer):
|
||||
"""Ast utils for create or update ast node."""
|
||||
|
||||
@staticmethod
|
||||
def erase_ast_from_function(ast_func: ast.FunctionDef, to_erase: ast.AST) -> bool:
|
||||
"""
|
||||
Erase ast node from ast.FunctionDef.
|
||||
|
||||
Args:
|
||||
ast_func (ast.FunctionDef): From which to search to_erase-node and erase.
|
||||
to_erase (ast.AST): Node to be erased.
|
||||
|
||||
Returns:
|
||||
A bool if to_erase-node been found and been erased.
|
||||
"""
|
||||
for body in ast_func.body:
|
||||
if id(body) == id(to_erase):
|
||||
ast_func.body.remove(body)
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def insert_assign_to_function(ast_func: ast.FunctionDef, targets: [ScopedValue], expr: ScopedValue,
|
||||
args: [ScopedValue] = None, kwargs: {str, ScopedValue}=None,
|
||||
index_ast: Optional[ast.AST] = None, insert_before=True) -> ast.AST:
|
||||
"""
|
||||
Insert an ast.Assign into an ast.FunctionDef.
|
||||
|
||||
Args:
|
||||
ast_func (ast.FunctionDef): Where new ast.Assign to be inserted into.
|
||||
targets ([ScopedValue]): Targets of ast.Assign.
|
||||
expr (ScopedValue): Func of ast.Call which is value of new ast.Assign.
|
||||
args ([ScopedValue]): Args of ast.Call which is value of new ast.Assign.
|
||||
kwargs ({str, ScopedValue}): Kwargs of ast.Call which is value of new ast.Assign.
|
||||
index_ast (Optional[ast.AST]): An ast_node indicates a position in 'ast_func' where new ast.Assign node to
|
||||
be inserted into. Default is None which means append new ast.Assign into
|
||||
'ast_func'.
|
||||
insert_before (bool): A bool indicates at before or at after of 'index_ast' where new ast.Assign node to be
|
||||
inserted into. Only valid when 'index_ast' is not None. Default is True which means
|
||||
inserting new ast.Assign before 'index_ast'.
|
||||
|
||||
Returns:
|
||||
An instance of ast.Assign which has been inserted into 'ast_func'.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If 'index_ast' is not contained in 'ast_func'.
|
||||
"""
|
||||
assign = AstModifier.create_call_assign(targets, expr, args, kwargs)
|
||||
return AstModifier.insert_assign_ast_to_function(ast_func, assign, index_ast, insert_before)
|
||||
|
||||
@staticmethod
|
||||
def insert_assign_ast_to_function(ast_func: ast.FunctionDef, ast_assign: ast.Assign,
|
||||
index_ast: Optional[ast.AST] = None, insert_before=True) -> ast.AST:
|
||||
"""
|
||||
Insert an ast.Assign into an ast.FunctionDef.
|
||||
|
||||
Args:
|
||||
ast_func (ast.FunctionDef): Where new ast.Assign to be inserted into.
|
||||
ast_assign (ast.Assign): An instance of ast.Assign to be inserted in.
|
||||
index_ast (Optional[ast.AST]): An ast_node indicates a position in 'ast_func' where new ast.Assign node to
|
||||
be inserted into. Default is None which means append new ast.Assign to
|
||||
'ast_func'.
|
||||
insert_before (bool): A bool indicates at before or at after of 'index_ast' where new ast.Assign node to be
|
||||
inserted into. Only valid when 'index_ast' is not None. Default is True which means
|
||||
inserting new ast.Assign before 'index_ast'.
|
||||
|
||||
Returns:
|
||||
An instance of ast.Assign which has been inserted into 'ast_func'.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If 'index_ast' is not contained in 'ast_func'.
|
||||
"""
|
||||
if index_ast is None:
|
||||
ast_func.body.append(ast_assign)
|
||||
ast.fix_missing_locations(ast_func)
|
||||
return ast_assign
|
||||
for index in range(0, len(ast_func.body)):
|
||||
if id(ast_func.body[index]) == id(index_ast):
|
||||
if insert_before:
|
||||
ast_func.body.insert(index, ast_assign)
|
||||
else:
|
||||
ast_func.body.insert(index + 1, ast_assign)
|
||||
ast.fix_missing_locations(ast_func)
|
||||
return ast_assign
|
||||
raise RuntimeError("index_ast is not contained in ast_func")
|
||||
|
||||
@staticmethod
|
||||
def append_global_vars_expr_to_init(init_func: ast.FunctionDef, targets: [ScopedValue],
|
||||
field: str) -> ast.AST:
|
||||
"""
|
||||
Append an ast.Assign to an ast.FunctionDef which is function named "__init__" in network. Value of new
|
||||
ast.Assign is an ast.Call represents get an object from global_vars dict.
|
||||
|
||||
While user inserting a custom op, the instance of new custom op is saved in a dict named global_vars. Rewrite
|
||||
need to get the custom op instance from global_vars in new "__init__" function of network:
|
||||
self.var1 = global_vars.get("var1")
|
||||
|
||||
Args:
|
||||
init_func (ast.FunctionDef): An instance of ast.FunctionDef which is "__init__" function of network.
|
||||
targets ([ScopedValue]): Targets of ast.Assign.
|
||||
field (str): A string represents name of new custom op field.
|
||||
|
||||
Returns:
|
||||
An instance of ast.Assign which has been appended to 'init_func'.
|
||||
"""
|
||||
return AstModifier.insert_assign_to_function(init_func, targets=targets,
|
||||
args=[ScopedValue.create_variable_value(field)],
|
||||
expr=ScopedValue(ValueType.NamingValue, "global_vars", "get"))
|
||||
|
||||
@staticmethod
|
||||
def create_call_assign(targets: [ScopedValue], expr: ScopedValue, args: [ScopedValue],
|
||||
kwargs: {str, ScopedValue}) -> ast.Assign:
|
||||
"""
|
||||
Create an instance of ast.Assign whose value must ba a ast.Call.
|
||||
|
||||
Args:
|
||||
targets ([ScopedValue]): Targets of ast.Assign.
|
||||
expr (ScopedValue): Func of ast.Call which is value of new ast.Assign.
|
||||
args ([ScopedValue]): Args of ast.Call which is value of new ast.Assign.
|
||||
kwargs ({str, ScopedValue}): Kwargs of ast.Call which is value of new ast.Assign.
|
||||
|
||||
Returns:
|
||||
An instance of ast.Assign.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If 'targets' is None.
|
||||
RuntimeError: If value_type of element of 'targets' is not ValueType.NamingValue.
|
||||
|
||||
RuntimeError: If length of 'targets' is not 1. Multi-targets will be support in the future.
|
||||
"""
|
||||
if targets is None or len(targets) != 1:
|
||||
raise RuntimeError("Only support one target in insert_cell_to_init now")
|
||||
if targets[0].type != ValueType.NamingValue:
|
||||
raise RuntimeError("Target must be a right-value, got: ", targets[0])
|
||||
if targets[0].scope:
|
||||
ast_target = ast.Attribute(ast.Name(targets[0].scope, ast.Load()), targets[0].value, ast.Store())
|
||||
else:
|
||||
ast_target = ast.Name(targets[0].value, ast.Store())
|
||||
call = AstModifier.create_call(expr, args, kwargs)
|
||||
result = ast.Assign(targets=[ast_target], value=call)
|
||||
ast.fix_missing_locations(result)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def create_call(expr: ScopedValue, args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None) -> ast.Call:
|
||||
"""
|
||||
Create an instance of ast.Call.
|
||||
|
||||
Args:
|
||||
expr (ScopedValue): Func of ast.Call.
|
||||
args ([ScopedValue]): Args of ast.Call.
|
||||
kwargs ({str, ScopedValue}): Kwargs of ast.Call.
|
||||
|
||||
Returns:
|
||||
An instance of ast.Call.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If value_type of 'expr' is ValueType.CustomObjValue.
|
||||
RuntimeError: If value_type of 'expr' is not ValueType.NamingValue.
|
||||
RuntimeError: If value_type of element of 'args' is ValueType.CustomObjValue.
|
||||
RuntimeError: If value_type of value of 'kwargs' is ValueType.CustomObjValue.
|
||||
TypeError: If expr is not an instance of ScopedValue.
|
||||
RuntimeError: If element of 'args' is not an instance of ScopedValue.
|
||||
RuntimeError: If value of 'kwargs' is not an instance of ScopedValue.
|
||||
"""
|
||||
if not isinstance(expr, ScopedValue):
|
||||
raise TypeError("expr should be ScopedValue, got: ", type(expr))
|
||||
if expr.type == ValueType.CustomObjValue:
|
||||
raise RuntimeError("Please handle custom-object first")
|
||||
if expr.type != ValueType.NamingValue:
|
||||
raise RuntimeError("Expr must not be a constant, because constant can not been called: ", expr.type)
|
||||
if expr.scope:
|
||||
ast_func = ast.Attribute(ast.Name(expr.scope, ast.Load()), expr.value, ast.Store())
|
||||
else:
|
||||
ast_func = ast.Name(expr.value, ast.Store())
|
||||
|
||||
ast_args = []
|
||||
if args is not None:
|
||||
for arg in args:
|
||||
if not isinstance(arg, ScopedValue):
|
||||
raise TypeError("arg should be ScopedValue, got: ", type(arg))
|
||||
if arg.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
|
||||
if arg.scope:
|
||||
raise RuntimeError("arg.scope should be empty")
|
||||
ast_args.append(ast.Constant(value=arg.value, kind=None))
|
||||
elif arg.type == ValueType.NamingValue:
|
||||
if arg.scope:
|
||||
ast_args.append(ast.Attribute(ast.Name(arg.scope, ast.Load()), arg.value, ast.Store()))
|
||||
else:
|
||||
ast_args.append(ast.Name(arg.value, ast.Store()))
|
||||
else:
|
||||
raise RuntimeError("Please handle custom-object first")
|
||||
keywords = []
|
||||
if kwargs is not None:
|
||||
for arg, value in kwargs.items():
|
||||
if not isinstance(value, ScopedValue):
|
||||
raise TypeError("value should be ScopedValue, got: ", type(value))
|
||||
if value.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
|
||||
if value.scope:
|
||||
raise RuntimeError("value.scope should be empty")
|
||||
keywords.append(ast.keyword(arg=arg, value=ast.Constant(value=value.value, kind=None)))
|
||||
elif value.type == ValueType.NamingValue:
|
||||
if value.scope:
|
||||
keywords.append(ast.keyword(arg=arg,
|
||||
value=ast.Attribute(ast.Name(value.scope, ast.Load()), value.value,
|
||||
ast.Store())))
|
||||
else:
|
||||
keywords.append(ast.keyword(arg=arg, value=ast.Name(value.value, ast.Store())))
|
||||
else:
|
||||
raise RuntimeError("Please handle custom-object first")
|
||||
result = ast.Call(func=ast_func, args=ast_args, keywords=keywords)
|
||||
ast.fix_missing_locations(result)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def update_arg_value(src_argument: ScopedValue, dst_ast: ast.AST):
|
||||
"""
|
||||
Update 'arg_value' by 'input_argument'
|
||||
|
||||
Args:
|
||||
src_argument (ScopedValue): An instance of ScopedValue represents new argument.
|
||||
dst_ast (ast.AST): Targets of ast.Assign.
|
||||
|
||||
Raises:
|
||||
TypeError: Input src_argument is not a ScopedValue
|
||||
RuntimeError: If 'dst_ast' is an instance of ast.Constant but type of 'src_argument' is not
|
||||
ValueType.IntValue, ValueType.FloatValue or ValueType.StringValue.
|
||||
RuntimeError: If 'dst_ast' is an instance of ast.Name or ast.Attribute but type of 'src_argument' is not
|
||||
ValueType.NamingValue.
|
||||
RuntimeError: When 'dst_ast' is an instance of ast.Name, scope of 'src_argument' is not empty.
|
||||
RuntimeError: When 'dst_ast' is an instance of ast.Attribute, value of 'dst_ast' is not an instance of
|
||||
ast.Name.
|
||||
RuntimeError: If 'dst_ast' is an instance of ast.Tuple but type of 'src_argument' is not
|
||||
ValueType.TupleValue.
|
||||
RuntimeError: If 'dst_ast' is an instance of ast.Constant, ast.Name, ast.Attribute or ast.Tuple.
|
||||
RuntimeError: When 'dst_ast' is an instance of ast.Tuple, length of elts of 'dst_ast' is not equal to length
|
||||
of value of 'src_argument'.
|
||||
"""
|
||||
if not isinstance(src_argument, ScopedValue):
|
||||
raise TypeError("src_argument should be ScopedValue, got: ", type(src_argument))
|
||||
if isinstance(dst_ast, ast.Constant):
|
||||
if src_argument.type not in [ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue]:
|
||||
raise RuntimeError("src_argument should be a IntValue, FloatValue or StringValue, got:",
|
||||
str(src_argument.type))
|
||||
dst_ast.value = src_argument.value
|
||||
return
|
||||
if isinstance(dst_ast, ast.Name):
|
||||
if src_argument.type != ValueType.NamingValue:
|
||||
raise RuntimeError("src_argument.type should equal to ValueType.NamingValue")
|
||||
if src_argument.scope:
|
||||
raise RuntimeError("src_argument.scope should be empty")
|
||||
dst_ast.id = src_argument.value
|
||||
return
|
||||
if isinstance(dst_ast, ast.Attribute):
|
||||
if src_argument.type != ValueType.NamingValue:
|
||||
raise RuntimeError("src_argument.type should equal to ValueType.NamingValue")
|
||||
attr_value = dst_ast.value
|
||||
if not isinstance(attr_value, ast.Name):
|
||||
raise RuntimeError("Only support ast.Name as value of attribute ", type(attr_value))
|
||||
attr_value.id = src_argument.scope
|
||||
dst_ast.attr = src_argument.value
|
||||
return
|
||||
if isinstance(dst_ast, ast.Tuple):
|
||||
if src_argument.type != ValueType.TupleValue:
|
||||
raise RuntimeError("src_argument.type should equal to ValueType.TupleValue")
|
||||
if len(src_argument.value) != len(dst_ast.elts):
|
||||
raise RuntimeError("src_argument.value and elts in ast should have same length")
|
||||
for elt_index, elt in enumerate(dst_ast.elts):
|
||||
AstModifier.update_arg_value(src_argument.value[elt_index], elt)
|
||||
return
|
||||
raise RuntimeError("keyword value type is not supported", type(dst_ast))
|
|
@ -0,0 +1,16 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Transformers for optimizing ast."""
|
||||
from .flatten_recursive_stmt import FlattenRecursiveStmt
|
|
@ -0,0 +1,154 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Ast optimizer for flatten recursive call."""
|
||||
from typing import Any, Tuple
|
||||
import ast
|
||||
from ast import FunctionDef
|
||||
from mindspore import log as logger
|
||||
|
||||
|
||||
class FlattenRecursiveStmt(ast.NodeTransformer):
|
||||
"""Ast optimizer for flatten recursive call."""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Constructor of FlattenRecursiveStmt.
|
||||
|
||||
Returns:
|
||||
An instance of ast optimizer for flatten recursive call.
|
||||
"""
|
||||
self._flatten_table: dict = {
|
||||
ast.Return: ["value"],
|
||||
ast.Call: ["args"],
|
||||
ast.BinOp: ["left", "right"],
|
||||
ast.BoolOp: ["values"],
|
||||
ast.unaryop: ["operand"],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _generate_target_name(node: ast.AST, target_names):
|
||||
"""Generate unique target name."""
|
||||
if isinstance(node, ast.Call):
|
||||
func = node.func
|
||||
if isinstance(func, ast.Name):
|
||||
target_name = func.id
|
||||
elif isinstance(func, ast.Attribute):
|
||||
target_name = func.attr
|
||||
else:
|
||||
logger.warning("unhandled type of func of ast.Call while generating new target name: %s ", type(func))
|
||||
target_name = "function"
|
||||
elif isinstance(node, ast.Return):
|
||||
target_name = "return_value"
|
||||
elif isinstance(node, (ast.BinOp, ast.boolop, ast.UnaryOp)):
|
||||
target_name = type(node.op).__name__
|
||||
else:
|
||||
logger.warning("unhandled type of node while generating new target name: %s ", type(node))
|
||||
target_name = type(node).__name__
|
||||
suffix = 0
|
||||
result = target_name
|
||||
while result in target_names:
|
||||
suffix += 1
|
||||
result = f"{target_name}_{suffix}"
|
||||
target_names.append(result)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _fill_in_original_target_names(target_names, node):
|
||||
"""Fill in original target names before getting unique names."""
|
||||
for function_index in range(len(node.body)):
|
||||
child = node.body[function_index]
|
||||
if not isinstance(child, ast.Assign):
|
||||
continue
|
||||
targets = child.targets
|
||||
for target in targets:
|
||||
if not isinstance(target, ast.Name):
|
||||
raise RuntimeError("currently only support ast.Name targets")
|
||||
target_name = target.id
|
||||
if target_name not in target_names:
|
||||
target_names.append(target_name)
|
||||
|
||||
@staticmethod
|
||||
def _create_new_assign_node(node: ast.AST, target_names) -> Tuple[str, ast.AST]:
|
||||
"""Create new assign node to be inserted into ast.FunctionDef."""
|
||||
if isinstance(node, (ast.Name, ast.Constant, ast.Num, ast.Str, ast.NameConstant, ast.Bytes, ast.Ellipsis)):
|
||||
return "", node
|
||||
new_target_name = FlattenRecursiveStmt._generate_target_name(node, target_names)
|
||||
return new_target_name, ast.Assign(targets=[ast.Name(id=new_target_name, ctx=ast.Store())], value=node)
|
||||
|
||||
def _flatten_statement(self, node: ast.AST, target_names) -> [ast.AST]:
|
||||
"""Flatten recursive statement according to different node type."""
|
||||
flatten_config = self._flatten_table.get(type(node))
|
||||
if flatten_config is None:
|
||||
return []
|
||||
results = []
|
||||
for todo_name in flatten_config:
|
||||
todos = getattr(node, todo_name)
|
||||
if isinstance(todos, list):
|
||||
new_list = []
|
||||
for todo in todos:
|
||||
new_target_name, new_node = FlattenRecursiveStmt._create_new_assign_node(todo, target_names)
|
||||
if id(new_node) == id(todo):
|
||||
new_list.append(todo)
|
||||
else:
|
||||
new_list.append(ast.Name(id=new_target_name, ctx=ast.Load()))
|
||||
results.append(new_node)
|
||||
setattr(node, todo_name, new_list)
|
||||
elif isinstance(todos, dict):
|
||||
new_dict = []
|
||||
for key, value in todos:
|
||||
new_target_name, new_node = FlattenRecursiveStmt._create_new_assign_node(value, target_names)
|
||||
if id(new_node) == id(value):
|
||||
new_dict[key] = value
|
||||
else:
|
||||
new_dict[key] = ast.Name(id=new_target_name, ctx=ast.Load())
|
||||
results.append(new_node)
|
||||
setattr(node, todo_name, new_dict)
|
||||
else:
|
||||
new_target_name, new_node = FlattenRecursiveStmt._create_new_assign_node(todos, target_names)
|
||||
if id(new_node) != id(todos):
|
||||
setattr(node, todo_name, ast.Name(id=new_target_name, ctx=ast.Load()))
|
||||
results.append(new_node)
|
||||
return results
|
||||
|
||||
def visit_FunctionDef(self, node: FunctionDef) -> Any:
|
||||
"""Traverse construct node and flatten recursive nodes."""
|
||||
if node.name != "construct":
|
||||
return node
|
||||
|
||||
target_names = []
|
||||
self._fill_in_original_target_names(target_names, node)
|
||||
index = len(node.body) - 1
|
||||
while index >= 0:
|
||||
child = node.body[index]
|
||||
if isinstance(child, ast.Assign):
|
||||
stmt = child.value
|
||||
elif isinstance(child, ast.Expr):
|
||||
stmt = child.value
|
||||
else:
|
||||
stmt = child
|
||||
results = self._flatten_statement(stmt, target_names)
|
||||
if results:
|
||||
results.reverse()
|
||||
for result in results:
|
||||
node.body.insert(index, result)
|
||||
index += 1
|
||||
index -= 1
|
||||
return node
|
||||
|
||||
def transform(self, ast_root):
|
||||
"""Interface of FlattenRecursiveStmt."""
|
||||
ast_root = self.visit(ast_root)
|
||||
ast_root = ast.fix_missing_locations(ast_root)
|
||||
return ast_root
|
|
@ -0,0 +1,159 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Unique name producer for target, name of node."""
|
||||
from typing import Union
|
||||
|
||||
from .node import Node
|
||||
from .api.node_type import NodeType
|
||||
|
||||
|
||||
class Namer:
|
||||
"""
|
||||
Used for unique identity in a class-scope. current used for target of construct-function.
|
||||
Namer records times of name been used, and add prefix to origin name for unique name. For example, when a Namer
|
||||
record "name1" has been used 10 times, when a new request require a unique name base on 'name1', namer will respond
|
||||
"name1_10" as unique name.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Constructor of Namer."""
|
||||
self._names: {str: int} = {}
|
||||
|
||||
@staticmethod
|
||||
def _real_name(name: str) -> str:
|
||||
"""
|
||||
Find real name. For example, "name1" is the real name of "name1_10", "name1" is the real name of "name1_10_3".
|
||||
If not find real name before find unique name, unique name may be not unique. For example:
|
||||
|
||||
1. "name1" has been used 10 times, which means "name1", "name1_2", "name1_3" ... "name1_10" has been used;
|
||||
2. new request require a unique name base on 'name1_5'
|
||||
3. If namer not find real name of "name1_5", namer will find that "name1_5" is never used and respond
|
||||
"name1_5" as unique name which is used before, actually.
|
||||
|
||||
Args:
|
||||
name (str): Origin name which may have digit prefix.
|
||||
|
||||
Returns:
|
||||
A string represents real-name.
|
||||
"""
|
||||
pos = name.rfind("_")
|
||||
if pos == -1:
|
||||
return name
|
||||
digit = True
|
||||
for i in range(pos + 1, len(name)):
|
||||
if not name[i].isdigit():
|
||||
digit = False
|
||||
break
|
||||
if digit:
|
||||
return Namer._real_name(name[:pos])
|
||||
return name
|
||||
|
||||
def get_name(self, origin_name: str) -> str:
|
||||
"""
|
||||
Get unique name from 'origin_name'.
|
||||
|
||||
Args:
|
||||
origin_name (str): Origin name which may be duplicated.
|
||||
|
||||
Returns:
|
||||
A string represents unique-name.
|
||||
"""
|
||||
origin_name = Namer._real_name(origin_name)
|
||||
number = self._names.get(origin_name)
|
||||
if number is None:
|
||||
self._names[origin_name] = 1
|
||||
return origin_name
|
||||
self._names[origin_name] = number + 1
|
||||
return f"{origin_name}_{number}"
|
||||
|
||||
def add_name(self, name: str):
|
||||
"""
|
||||
Add a name to Namer which should be unique.
|
||||
|
||||
Args:
|
||||
name (str): A name should be unique in current namer.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If name is not unique in current namer.
|
||||
"""
|
||||
real_name = Namer._real_name(name)
|
||||
number = self._names.get(real_name)
|
||||
if number is not None:
|
||||
raise RuntimeError("name duplicated: ", name)
|
||||
self._names[name] = 1
|
||||
|
||||
|
||||
class TargetNamer(Namer):
|
||||
"""
|
||||
Used for unique-ing targets of node.
|
||||
"""
|
||||
def get_real_arg(self, origin_arg: str) -> str:
|
||||
"""
|
||||
Get real argument from 'origin_arg' because target of node which produces 'origin_arg' may be change for unique.
|
||||
|
||||
Args:
|
||||
origin_arg (str): Origin argument string which may be undefined cause to target unique-lize.
|
||||
|
||||
Returns:
|
||||
A string represents real argument name.
|
||||
"""
|
||||
num = self._names.get(origin_arg)
|
||||
if num is None or num == 1:
|
||||
return origin_arg
|
||||
return f"{origin_arg}_{num - 1}"
|
||||
|
||||
|
||||
class NodeNamer(Namer):
|
||||
"""
|
||||
Used for unique-ing node-name which is also used as field of init-function and key of global_vars
|
||||
"""
|
||||
|
||||
def get_name(self, node_or_name: Union[Node, str]) -> str:
|
||||
"""
|
||||
Override get_name in Namer class.
|
||||
Get unique node_name from 'origin_name' or an instance of node.
|
||||
|
||||
Args:
|
||||
node_or_name (Union[Node, str]): A string represents candidate node_name or an instance of node who require
|
||||
A unique node_name.
|
||||
|
||||
Returns:
|
||||
A string represents unique node_name.
|
||||
"""
|
||||
if isinstance(node_or_name, Node):
|
||||
origin_name = node_or_name.get_name()
|
||||
if origin_name is None or not origin_name:
|
||||
if node_or_name.get_node_type() == NodeType.CallCell:
|
||||
if not isinstance(node_or_name, Node):
|
||||
raise TypeError("node_or_name should be Node, got: ", type(node_or_name))
|
||||
targets = node_or_name.get_targets()
|
||||
# return node and head node will not call this method
|
||||
if not targets:
|
||||
raise RuntimeError("node should has at lease one target except return-node and head-node: ",
|
||||
node_or_name)
|
||||
origin_name = str(targets[0].value)
|
||||
elif node_or_name.get_node_type() == NodeType.Python:
|
||||
origin_name = node_or_name.get_instance().__name__
|
||||
elif node_or_name.get_node_type() == NodeType.Input:
|
||||
origin_name = "parameter"
|
||||
else:
|
||||
raise RuntimeError("Node type unsupported:", node_or_name.get_node_type())
|
||||
elif isinstance(node_or_name, str):
|
||||
if not node_or_name:
|
||||
raise RuntimeError("input node_name is empty.")
|
||||
origin_name = node_or_name
|
||||
else:
|
||||
raise RuntimeError("unexpected type of node_or_name: ", type(node_or_name))
|
||||
return super(NodeNamer, self).get_name(origin_name)
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,45 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Base class of parser."""
|
||||
import abc
|
||||
import ast
|
||||
|
||||
from .symbol_tree import SymbolTree
|
||||
|
||||
|
||||
class Parser(abc.ABC):
|
||||
"""
|
||||
DFS into a ast_node until add node into SymbolTree
|
||||
"""
|
||||
|
||||
def target(self) -> type:
|
||||
"""
|
||||
Get type of ast which could be accepted by current parser.
|
||||
|
||||
Returns:
|
||||
A type of ast.
|
||||
"""
|
||||
return type(None)
|
||||
|
||||
@abc.abstractmethod
|
||||
def process(self, stree: SymbolTree, node: ast.AST):
|
||||
"""
|
||||
Parse input ast node and add parse result into SymbolTree.
|
||||
|
||||
Args:
|
||||
stree (SymbolTree): current symbol_tree
|
||||
node (ast.AST): node who is tried to be parsed
|
||||
"""
|
||||
raise NotImplementedError
|
|
@ -0,0 +1,86 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Parser register."""
|
||||
from typing import Optional
|
||||
|
||||
from .parser import Parser
|
||||
|
||||
|
||||
class ParserRegister:
|
||||
"""Parser register."""
|
||||
|
||||
def __init__(self):
|
||||
self._parsers: dict = {}
|
||||
|
||||
@classmethod
|
||||
def instance(cls) -> 'ParserRegister':
|
||||
"""
|
||||
Get singleton of ParserRegister.
|
||||
|
||||
Returns:
|
||||
An instance of ParserRegister.
|
||||
"""
|
||||
if not hasattr(ParserRegister, "_instance"):
|
||||
ParserRegister._instance = ParserRegister()
|
||||
return ParserRegister._instance
|
||||
|
||||
@staticmethod
|
||||
def reg_parser(parser: Parser):
|
||||
"""
|
||||
Register a 'parser' to current ParserRegister.
|
||||
|
||||
Args:
|
||||
parser (Parser): An instance of Parser to be registered.
|
||||
"""
|
||||
if isinstance(parser, Parser):
|
||||
ParserRegister.instance()._parsers[parser.target()] = parser
|
||||
|
||||
def get_parser(self, ast_type: type) -> Optional[Parser]:
|
||||
"""
|
||||
Get parser from current ParserRegister by type of ast.
|
||||
|
||||
Args:
|
||||
ast_type (type): An type of ast which want to be parsed.
|
||||
|
||||
Returns:
|
||||
An instance of Parser if there existing suitable parser in current ParserRegister else None.
|
||||
"""
|
||||
return self._parsers.get(ast_type)
|
||||
|
||||
def get_parsers(self) -> [Parser]:
|
||||
"""
|
||||
Get all parsers registered in current ParserRegister.
|
||||
|
||||
Returns:
|
||||
An list of instances of Parser.
|
||||
"""
|
||||
return self._parsers
|
||||
|
||||
|
||||
class ParserRegistry:
|
||||
"""Parser registry."""
|
||||
|
||||
def __init__(self, parser: Parser):
|
||||
ParserRegister.instance().reg_parser(parser)
|
||||
|
||||
|
||||
def reg_parser(parser: Parser):
|
||||
"""
|
||||
A global method for registering parser into ParserRegister singleton.
|
||||
|
||||
Args:
|
||||
parser (Parser): An instance of Parser to be registered.
|
||||
"""
|
||||
return ParserRegistry(parser)
|
|
@ -0,0 +1,17 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Parsers for resolve ast to SymbolTree
|
||||
"""
|
|
@ -0,0 +1,61 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Parse ast.arguments to input-node of SymbolTree."""
|
||||
import ast
|
||||
|
||||
from ..parser import Parser
|
||||
from ..parser_register import reg_parser
|
||||
from ..symbol_tree import SymbolTree
|
||||
|
||||
|
||||
class ArgumentsParser(Parser):
|
||||
"""Parse ast.arguments to input-node of SymbolTree."""
|
||||
|
||||
def target(self):
|
||||
"""Parse target type"""
|
||||
return ast.arguments
|
||||
|
||||
def process(self, stree: SymbolTree, node: ast.arguments):
|
||||
"""
|
||||
Parse ast.arguments and create input-node to stree.
|
||||
|
||||
Args:
|
||||
stree (SymbolTree): symbol tree under parsing.
|
||||
node (ast.arguments): argument node in construct.
|
||||
|
||||
Raises:
|
||||
RuntimeError: Types of node.args elements are not ast.arg.
|
||||
"""
|
||||
if hasattr(node, "posonlyargs"):
|
||||
stree.try_append_python_node(node, node.posonlyargs)
|
||||
|
||||
for arg in node.args:
|
||||
if not isinstance(arg, ast.arg):
|
||||
raise RuntimeError("Unsupported ast type in arguments arg: ", arg)
|
||||
stree.append_input_node(arg.arg)
|
||||
|
||||
if hasattr(node, "vararg"):
|
||||
stree.try_append_python_node(node, node.vararg)
|
||||
if hasattr(node, "kwonlyargs"):
|
||||
stree.try_append_python_node(node, node.kwonlyargs)
|
||||
if hasattr(node, "kw_defaults"):
|
||||
stree.try_append_python_node(node, node.kw_defaults)
|
||||
if hasattr(node, "kwarg"):
|
||||
stree.try_append_python_node(node, node.kwarg)
|
||||
if hasattr(node, "defaults"):
|
||||
stree.try_append_python_node(node, node.defaults)
|
||||
|
||||
|
||||
g_arguments_parser = reg_parser(ArgumentsParser())
|
|
@ -0,0 +1,252 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Parse ast.Assign in construct function to node of SymbolTree."""
|
||||
import ast
|
||||
import astunparse
|
||||
|
||||
from mindspore import log as logger
|
||||
from ..symbol_tree import SymbolTree
|
||||
from ..node import Node, TreeNode
|
||||
from ..parser import Parser
|
||||
from ..parser_register import reg_parser
|
||||
from ..api.scoped_value import ScopedValue
|
||||
from ..symbol_tree_builder import SymbolTreeBuilder
|
||||
|
||||
|
||||
class AssignParser(Parser):
|
||||
"""Parse ast.Assign in construct function to node of SymbolTree."""
|
||||
|
||||
def target(self):
|
||||
"""Parse target type."""
|
||||
return ast.Assign
|
||||
|
||||
@staticmethod
|
||||
def _create_scopedvalue_from_tuple_ast(node: ast.Tuple) -> ScopedValue:
|
||||
"""
|
||||
Create ScopedValue from a tuple ast node.
|
||||
|
||||
Args:
|
||||
node (ast.Tuple): A tuple node.
|
||||
|
||||
Returns:
|
||||
An instance of ScopedValue.
|
||||
|
||||
Raises:
|
||||
RuntimeError: Only support ast.Constant as elts of ast.Tuple.
|
||||
"""
|
||||
tuple_elts = node.elts
|
||||
tuple_values = []
|
||||
for tuple_elt in tuple_elts:
|
||||
if not isinstance(tuple_elt, ast.Constant):
|
||||
raise RuntimeError("Only support ast.Constant as elts of ast.Tuple.")
|
||||
tuple_values.append(tuple_elt.value)
|
||||
return ScopedValue.create_variable_value(tuple(tuple_values))
|
||||
|
||||
@staticmethod
|
||||
def _create_scopedvalue(node: ast.expr) -> ScopedValue:
|
||||
"""
|
||||
Create ScopedValue from an ast node.
|
||||
|
||||
Args:
|
||||
node (ast.expr): An ast node.
|
||||
|
||||
Returns:
|
||||
An instance of ScopedValue.
|
||||
|
||||
Raises:
|
||||
RuntimeError: Value of target of ast.Assign should be an ast.Name when target is an ast.Attribute.
|
||||
RuntimeError: Type of input node is unsupported.
|
||||
"""
|
||||
if isinstance(node, ast.Name):
|
||||
return ScopedValue.create_naming_value(node.id)
|
||||
if isinstance(node, ast.Attribute):
|
||||
scope = node.value
|
||||
if not isinstance(scope, ast.Name):
|
||||
raise RuntimeError("value of target of ast.Assign should be a ast.Name when target is a ast.Attribute.")
|
||||
return ScopedValue.create_naming_value(node.attr, scope.id)
|
||||
if isinstance(node, ast.Tuple):
|
||||
return AssignParser._create_scopedvalue_from_tuple_ast(node)
|
||||
if isinstance(node, ast.Constant):
|
||||
return ScopedValue.create_variable_value(node.value)
|
||||
raise RuntimeError("Unsupported ast type to argument:", node)
|
||||
|
||||
@staticmethod
|
||||
def _get_func_name(ast_node: ast.Call) -> str:
|
||||
"""
|
||||
Get the func name from ast.Call.
|
||||
|
||||
Args:
|
||||
ast_node (ast.Call): Input ast.Call node.
|
||||
|
||||
Returns:
|
||||
Func name.
|
||||
|
||||
Raises:
|
||||
RuntimeError: Func of input ast node is not ast.Name or ast.Attribute.
|
||||
"""
|
||||
func = ast_node.func
|
||||
if isinstance(func, ast.Name):
|
||||
return func.id
|
||||
if isinstance(func, ast.Attribute):
|
||||
return func.attr
|
||||
raise RuntimeError("FuncValue is should be Name or a Attribute:", astunparse.unparse(func))
|
||||
|
||||
@staticmethod
|
||||
def _get_func_scope(ast_node: ast.Call) -> str:
|
||||
"""
|
||||
Get the func scope from ast.Call.
|
||||
|
||||
Args:
|
||||
ast_node (ast.Call): Input ast.Call node.
|
||||
|
||||
Returns:
|
||||
Func scope.
|
||||
|
||||
Raises:
|
||||
RuntimeError: FuncValue is not an ast.Name when func is an ast.Attribute.
|
||||
RuntimeError: Func of input ast node is not ast.Name or ast.Attribute.
|
||||
"""
|
||||
func = ast_node.func
|
||||
if isinstance(func, ast.Name):
|
||||
return ""
|
||||
if isinstance(func, ast.Attribute):
|
||||
value = func.value
|
||||
if not isinstance(value, ast.Name):
|
||||
raise RuntimeError("FuncValue is should be Name:", ast.dump(func))
|
||||
return value.id
|
||||
raise RuntimeError("FuncValue is should be Name or a Attribute:", ast.dump(func))
|
||||
|
||||
@staticmethod
|
||||
def _get_symbol_object(symbol_name, origin_net):
|
||||
"""
|
||||
Get the func scope from ast.Call.
|
||||
|
||||
Args:
|
||||
symbol_name (str): Func name.
|
||||
origin_net ([nn.Cell]): Network instance.
|
||||
|
||||
Returns:
|
||||
Symbol Object.
|
||||
"""
|
||||
var_dict = origin_net.__dict__
|
||||
for key, value in var_dict["_cells"].items():
|
||||
if key == symbol_name:
|
||||
return value
|
||||
|
||||
for key, value in var_dict["_primitives"].items():
|
||||
if key == symbol_name:
|
||||
return value
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _create_kwargs(keywords: [ast.keyword]) -> {str, ScopedValue}:
|
||||
"""
|
||||
Transfer ast.Call keywords to a dict of ScopedValue when creating a symbol tree node.
|
||||
|
||||
Args:
|
||||
keywords ([ast.keyword]): Keywords of ast.Call node.
|
||||
|
||||
Returns:
|
||||
A dict of ScopedValue.
|
||||
"""
|
||||
results = {}
|
||||
for keyword in keywords:
|
||||
results[keyword.arg] = AssignParser._create_scopedvalue(keyword.value)
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def _convert_ast_call_to_node(ast_node: ast.Call, father_ast_node: ast.Assign, stree: SymbolTree) -> Node:
|
||||
"""
|
||||
Convert ast.Call to a symbol tree node.
|
||||
|
||||
Args:
|
||||
ast_node ([ast.Call]): An ast.Call of assign node in construct.
|
||||
father_ast_node ([ast.Assign]): Assign node in construct.
|
||||
stree ([SymbolTree]): Symbol Tree under parsing.
|
||||
|
||||
Returns:
|
||||
An instance of Node in Symbol Tree.
|
||||
|
||||
Raises:
|
||||
RuntimeError: kwargs in construct function assign is unsupported.
|
||||
"""
|
||||
target = AssignParser._create_scopedvalue(father_ast_node.targets[0])
|
||||
func_name = AssignParser._get_func_name(ast_node)
|
||||
if func_name is None or func_name == "":
|
||||
raise RuntimeError("function name not exist")
|
||||
func_scope = AssignParser._get_func_scope(ast_node)
|
||||
func = ScopedValue.create_naming_value(func_name, func_scope)
|
||||
call_args = [AssignParser._create_scopedvalue(arg) for arg in ast_node.args]
|
||||
call_kwargs = AssignParser._create_kwargs(ast_node.keywords)
|
||||
if ast_node.keywords:
|
||||
raise RuntimeError("kwargs in construct function assign is unsupported.")
|
||||
|
||||
obj = AssignParser._get_symbol_object(func_name, stree.get_origin_network())
|
||||
# need check if node is a callmethod, like: x = len(x)
|
||||
# need check if node is a callprimitive, like: x = x * 5
|
||||
is_sub_tree = False
|
||||
if is_sub_tree:
|
||||
stb = SymbolTreeBuilder(obj)
|
||||
new_stree = stb.build()
|
||||
return TreeNode(new_stree, father_ast_node, [target], func, call_args, call_kwargs, func_name,
|
||||
new_stree.get_origin_network())
|
||||
return Node.create_call_cell(obj, father_ast_node, [target], func, call_args, call_kwargs, func_name)
|
||||
|
||||
def process(self, stree: SymbolTree, node: ast.Assign):
|
||||
"""
|
||||
Parse ast.Assign and create a node in symbol tree.
|
||||
Will create node when value of ast.Assign is in [ast.Call, ast.Name, ast.Constant, ast.Attribute].
|
||||
Will create python node when value of ast.Assign is in
|
||||
[ast.BinOp, ast.BoolOp, ast.Subscript, ast.List, ast.Tuple, ast.Dict].
|
||||
Other value types are not supported.
|
||||
|
||||
Args:
|
||||
stree ([SymbolTree]): Symbol Tree under parsing.
|
||||
node ([ast.Assign]): An ast.Assign node.
|
||||
|
||||
Raises:
|
||||
RuntimeError: Only support one target in assign now.
|
||||
RuntimeError: Unsupported node type in construct function.
|
||||
"""
|
||||
targets = node.targets
|
||||
if len(targets) != 1:
|
||||
raise RuntimeError("Only support one target in assign now")
|
||||
value = node.value
|
||||
if isinstance(value, ast.Call):
|
||||
node_ = AssignParser._convert_ast_call_to_node(value, node, stree)
|
||||
stree.append_origin_field(node_)
|
||||
elif isinstance(value, (ast.BinOp, ast.BoolOp, ast.Subscript)):
|
||||
logger.warning(f"ops-call({astunparse.unparse(node)}) in assign will be supported in near feature, "
|
||||
f"ignored as a python node now")
|
||||
stree.try_append_python_node(node, node)
|
||||
elif isinstance(value, (ast.Name, ast.Constant, ast.Attribute, ast.Num, ast.NameConstant, ast.Bytes, ast.Str)):
|
||||
if isinstance(value, ast.Name):
|
||||
node_name = "name_assign"
|
||||
elif isinstance(value, ast.Constant):
|
||||
node_name = "constant_assign"
|
||||
else:
|
||||
node_name = "attribute_assign"
|
||||
target = AssignParser._create_scopedvalue(node.targets[0])
|
||||
call_args = [AssignParser._create_scopedvalue(value)]
|
||||
node_ = Node.create_call_pass_through_method(node, [target], call_args, {}, node_name)
|
||||
stree.append_origin_field(node_)
|
||||
elif isinstance(value, (ast.List, ast.Tuple, ast.Dict)):
|
||||
# add these as callmethod node if necessary
|
||||
stree.try_append_python_node(node, node)
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported statement({astunparse.unparse(node)}) in construct function!")
|
||||
|
||||
|
||||
g_assign_parser = reg_parser(AssignParser())
|
|
@ -0,0 +1,170 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Parse ast.ClassDef which is subclass of Cell to SymbolTree."""
|
||||
import ast
|
||||
|
||||
from mindspore import log as logger
|
||||
from ..symbol_tree import SymbolTree
|
||||
from ..parser import Parser
|
||||
from ..parser_register import ParserRegister, reg_parser
|
||||
from ..api.scoped_value import ScopedValue
|
||||
from ..ast_modifier import AstModifier
|
||||
|
||||
|
||||
class ClassDefParser(Parser):
|
||||
"""Parse ast.ClassDef which is subclass of Cell to SymbolTree."""
|
||||
|
||||
def target(self):
|
||||
"""Parse target type"""
|
||||
return ast.ClassDef
|
||||
|
||||
@staticmethod
|
||||
def _process_init_func_ast(init_ast: ast.FunctionDef, ori_cls_name: str, opt_cls_name: str):
|
||||
"""Process init func"""
|
||||
super_index = ClassDefParser._modify_super_expr_of_init_func(init_ast, ori_cls_name, opt_cls_name)
|
||||
ClassDefParser._modify_arguments_of_init_func(init_ast)
|
||||
ClassDefParser._replace_ori_field_of_init_func(init_ast.body, super_index)
|
||||
ClassDefParser._insert_handler_to_init_func(init_ast, super_index)
|
||||
|
||||
@staticmethod
|
||||
def _modify_super_expr_of_init_func(ast_init_fn: ast.FunctionDef, ori_cls_name: str, opt_cls_name: str) -> int:
|
||||
"""Modify network name in super(XXnet).__init__()"""
|
||||
if not ast_init_fn.body:
|
||||
return -1
|
||||
super_index = -1
|
||||
super_call_args = None
|
||||
while True:
|
||||
super_index += 1
|
||||
expr = ast_init_fn.body[super_index]
|
||||
if not isinstance(expr, ast.Expr):
|
||||
continue
|
||||
expr_value = expr.value
|
||||
if not isinstance(expr_value, ast.Call):
|
||||
continue
|
||||
expr_value_func = expr_value.func
|
||||
if not isinstance(expr_value_func, ast.Attribute):
|
||||
continue
|
||||
expr_value_func_value = expr_value_func.value
|
||||
if expr_value_func.attr != "__init__" or not isinstance(expr_value_func_value, ast.Call):
|
||||
continue
|
||||
expr_value_func_value_func = expr_value_func_value.func
|
||||
if not isinstance(expr_value_func_value_func, ast.Name) or expr_value_func_value_func.id != "super":
|
||||
continue
|
||||
super_call_args = expr_value_func_value.args
|
||||
break
|
||||
if super_call_args is None or not isinstance(super_call_args, list) or len(super_call_args) != 2:
|
||||
return super_index
|
||||
super_call_arg = super_call_args[0]
|
||||
if super_call_arg.id != ori_cls_name:
|
||||
raise RuntimeError("super_call_arg.id should equal to ori_cls_name")
|
||||
super_call_arg.id = opt_cls_name
|
||||
return super_index
|
||||
|
||||
@staticmethod
|
||||
def _modify_arguments_of_init_func(ast_init_fn: ast.FunctionDef):
|
||||
"""Replace init function input parameters to self and global_vars."""
|
||||
arg_self = ast.arg(arg="self", annotation="")
|
||||
arg_global_vars = ast.arg(arg="global_vars", annotation="")
|
||||
ast_init_fn.args = ast.arguments(args=[arg_self, arg_global_vars], posonlyargs=[], kwonlyargs=[],
|
||||
kw_defaults=[], defaults=[], vararg=None, kwarg=None)
|
||||
ast.fix_missing_locations(ast_init_fn)
|
||||
|
||||
@staticmethod
|
||||
def _replace_ori_field_of_init_func(bodies: [], super_index: int):
|
||||
"""
|
||||
Replace original field in init func to self.XX = getattr(self._handler, "XX").
|
||||
Only keep following two kinds of ast nodes in bodies right now:
|
||||
1. Ast.If and test is self.XX.
|
||||
2. Ast.Assign and target is self.XX.
|
||||
|
||||
Args:
|
||||
bodies ([]): bodied of init ast.FunctionDef.
|
||||
super_index (int): index of super().__init__() in bodies.
|
||||
|
||||
Raises:
|
||||
RuntimeError: Not support multi-targets in assign.
|
||||
RuntimeError: Only support target.value in [ast.Name] in assign node.
|
||||
"""
|
||||
body_index_to_be_deleted = []
|
||||
for body_index, body in enumerate(bodies):
|
||||
if body_index == super_index:
|
||||
continue # ignoring super.__init__()
|
||||
if isinstance(body, ast.If) and isinstance(body.test, ast.Attribute) \
|
||||
and isinstance(body.test.value, ast.Name) and body.test.value.id == 'self':
|
||||
ClassDefParser._replace_ori_field_of_init_func(body.body, -1)
|
||||
ClassDefParser._replace_ori_field_of_init_func(body.orelse, -1)
|
||||
continue
|
||||
if not isinstance(body, ast.Assign): # if not assign node, delete
|
||||
body_index_to_be_deleted.append(body_index)
|
||||
continue
|
||||
if len(body.targets) != 1:
|
||||
raise RuntimeError("Not support multi-targets in assign now!")
|
||||
target = body.targets[0]
|
||||
if not isinstance(target, ast.Attribute): # only keep class member
|
||||
body_index_to_be_deleted.append(body_index)
|
||||
continue
|
||||
if not isinstance(target.value, ast.Name):
|
||||
raise RuntimeError("Only support target.value in ast.Name now!")
|
||||
target_value: ast.Name = target.value
|
||||
if target_value.id != "self":
|
||||
body_index_to_be_deleted.append(body_index)
|
||||
continue
|
||||
field_name = target.attr
|
||||
body.value = ast.Call(ast.Name('getattr', ast.Load()),
|
||||
[ast.Attribute(ast.Name('self', ast.Load()), '_handler', ast.Load()),
|
||||
ast.Constant(value=field_name, kind=None)], [])
|
||||
for counter, index in enumerate(body_index_to_be_deleted):
|
||||
bodies.pop(index - counter)
|
||||
|
||||
@staticmethod
|
||||
def _insert_handler_to_init_func(ast_init_fn: ast.FunctionDef, super_index):
|
||||
"""Insert 'self._handler = global_vars.get('handler')' to init ast.FunctionDef.body"""
|
||||
if super_index == -1:
|
||||
super_index = 0
|
||||
AstModifier.insert_assign_to_function(ast_init_fn, [ScopedValue.create_naming_value("_handler", "self")],
|
||||
ScopedValue.create_naming_value("get", "global_vars"),
|
||||
[ScopedValue.create_variable_value("handler")], None,
|
||||
ast_init_fn.body[super_index], False)
|
||||
|
||||
def process(self, stree: SymbolTree, node: ast.ClassDef):
|
||||
"""
|
||||
Parse init and construct in ast.ClassDef.
|
||||
|
||||
Args:
|
||||
stree ([SymbolTree]): Symbol Tree under parsing.
|
||||
node ([ast.ClassDef]): An ast.ClassDef node.
|
||||
"""
|
||||
# change class name
|
||||
node.name = stree.get_opt_cls_name()
|
||||
|
||||
stree.set_class_ast(node)
|
||||
|
||||
for body in node.body:
|
||||
if isinstance(body, ast.FunctionDef):
|
||||
if body.name == "__init__":
|
||||
ClassDefParser._process_init_func_ast(body, stree.get_ori_cls_name(), stree.get_opt_cls_name())
|
||||
stree.set_init_func_ast(body)
|
||||
elif body.name == "construct":
|
||||
parser: Parser = ParserRegister.instance().get_parser(ast.FunctionDef)
|
||||
parser.process(stree, body)
|
||||
else:
|
||||
logger.warning(
|
||||
"Ignoring ast.FunctionDef in ast.ClassDef except __init__ and construct function: %s",
|
||||
body.name)
|
||||
else:
|
||||
logger.warning("Ignoring unsupported node(%s) in ast.ClassDef.", type(body).__name__)
|
||||
|
||||
|
||||
g_classdef_parser = reg_parser(ClassDefParser())
|
|
@ -0,0 +1,53 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Parse bodies of ast.FunctionDef which is construct function to nodes of SymbolTree."""
|
||||
import ast
|
||||
|
||||
from ..parser_register import ParserRegister, reg_parser
|
||||
from ..parser import Parser
|
||||
from ..symbol_tree import SymbolTree
|
||||
|
||||
|
||||
class FunctionDefParser(Parser):
|
||||
"""Parse bodies of ast.FunctionDef which is construct function to nodes of SymbolTree."""
|
||||
|
||||
def target(self):
|
||||
"""Parse target type"""
|
||||
return ast.FunctionDef
|
||||
|
||||
def process(self, stree: SymbolTree, node: ast.FunctionDef):
|
||||
"""Parse bodies of ast.FunctionDef which is construct function to nodes of SymbolTree."""
|
||||
stree.set_ast_root(node)
|
||||
# parse args as inputs of stree
|
||||
arguments: ast.arguments = node.args
|
||||
parser: Parser = ParserRegister.instance().get_parser(ast.arguments)
|
||||
parser.process(stree, arguments)
|
||||
|
||||
# parse body as node of stree
|
||||
for body in node.body:
|
||||
# avoid add dead code, so we need to break if return is added.
|
||||
parser: Parser = ParserRegister.instance().get_parser(type(body))
|
||||
if parser is None:
|
||||
stree.append_python_node(node, body)
|
||||
else:
|
||||
parser.process(stree, body)
|
||||
|
||||
if hasattr(node, "decorator_list"):
|
||||
stree.try_append_python_node(node, node.decorator_list)
|
||||
if hasattr(node, "returns"):
|
||||
stree.try_append_python_node(node, node.returns)
|
||||
|
||||
|
||||
g_functiondef_parser = reg_parser(FunctionDefParser())
|
|
@ -0,0 +1,124 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Parse ast.Module to SymbolTrees."""
|
||||
from typing import Any
|
||||
import os
|
||||
import ast
|
||||
import copy
|
||||
import inspect
|
||||
import astunparse
|
||||
|
||||
from mindspore import log as logger
|
||||
from ..symbol_tree import SymbolTree
|
||||
from ..parser import Parser
|
||||
from ..parser_register import ParserRegister, reg_parser
|
||||
|
||||
|
||||
class ClassFinder(ast.NodeVisitor):
|
||||
"""Find all ast.ClassDef in input ast node."""
|
||||
|
||||
def __init__(self):
|
||||
"""Keep all found ast.ClassDef in self._classes"""
|
||||
self._classes: [ast.ClassDef] = []
|
||||
|
||||
def visit_ClassDef(self, node: ast.ClassDef) -> Any:
|
||||
"""Iterate over all nodes and save ast.ClassDef nodes."""
|
||||
self._classes.append(node)
|
||||
|
||||
def find_all_classes(self, node: ast.AST) -> [ast.ClassDef]:
|
||||
"""Interface of ClassFinder."""
|
||||
self.visit(node)
|
||||
return self._classes
|
||||
|
||||
|
||||
class ModuleParser(Parser):
|
||||
"""Parse ast.Module to SymbolTrees."""
|
||||
|
||||
def target(self):
|
||||
"""Parse target type"""
|
||||
return ast.Module
|
||||
|
||||
@staticmethod
|
||||
def _find_class(ast_node: ast.Module) -> ast.ClassDef:
|
||||
"""Find all ast.ClassDef in ast.Module, only support one ast.ClassDef in ast.Module now."""
|
||||
visitor = ClassFinder()
|
||||
classes = visitor.find_all_classes(ast_node)
|
||||
if not classes:
|
||||
raise RuntimeError("No class in module")
|
||||
if len(classes) > 1:
|
||||
raise RuntimeError("Multi-class in module is not supported now")
|
||||
return classes[0]
|
||||
|
||||
@staticmethod
|
||||
def get_import_node(ast_root):
|
||||
"""Iterate over ast_root and return all ast.Import nodes or ast.ImportFrom nodes in ast_root."""
|
||||
import_nodes = []
|
||||
|
||||
class GetImportNode(ast.NodeVisitor):
|
||||
"""Find all import nodes from input ast node."""
|
||||
def visit_Import(self, node: ast.Import) -> Any:
|
||||
"""Iterate over all nodes and save ast.Import nodes."""
|
||||
import_nodes.append(copy.deepcopy(node))
|
||||
return node
|
||||
|
||||
def visit_ImportFrom(self, node: ast.ImportFrom) -> Any:
|
||||
"""Iterate over all nodes and save ast.ImportFrom nodes."""
|
||||
import_nodes.append(copy.deepcopy(node))
|
||||
return node
|
||||
|
||||
def get_node(self, input_ast):
|
||||
"""Interface of GetImportNode."""
|
||||
self.generic_visit(input_ast)
|
||||
return True
|
||||
|
||||
get_node_handler = GetImportNode()
|
||||
get_node_handler.get_node(ast_root)
|
||||
return import_nodes
|
||||
|
||||
@staticmethod
|
||||
def _add_import_to_module(module: ast.Module, origin_net):
|
||||
"""Insert two groups of import nodes to ast.Module, common ones and those from class definition file."""
|
||||
module.body.insert(0, ast.Import([ast.alias(name='mindspore', asname=None)]))
|
||||
module.body.insert(1, ast.ImportFrom(module='mindspore', names=[ast.alias(name='nn', asname=None)], level=0))
|
||||
module.body.insert(2, ast.ImportFrom(module='mindspore.nn', names=[ast.alias(name='Cell', asname=None)],
|
||||
level=0))
|
||||
origin_net_source_code_file = inspect.getfile(type(origin_net))
|
||||
if not os.path.exists(origin_net_source_code_file):
|
||||
raise RuntimeError("File ", origin_net_source_code_file, " not exist")
|
||||
try:
|
||||
with open(origin_net_source_code_file, "r") as f:
|
||||
source_code = f.read()
|
||||
import_nodes = ModuleParser.get_import_node(ast.parse(source_code))
|
||||
except RuntimeError:
|
||||
raise RuntimeError("get import nodes error")
|
||||
if import_nodes:
|
||||
for import_index, import_node in enumerate(import_nodes):
|
||||
module.body.insert(import_index + 3, import_node)
|
||||
ast.fix_missing_locations(module)
|
||||
|
||||
def process(self, stree: SymbolTree, node: ast.Module):
|
||||
"""Process ast.ClassDef nodes in ast.Module."""
|
||||
ModuleParser._add_import_to_module(node, stree.get_origin_network())
|
||||
class_ast = ModuleParser._find_class(node)
|
||||
stree.set_class_ast(class_ast)
|
||||
for body in node.body:
|
||||
if isinstance(body, ast.ClassDef):
|
||||
parser: Parser = ParserRegister.instance().get_parser(ast.ClassDef)
|
||||
parser.process(stree, body)
|
||||
else:
|
||||
logger.info(f"Ignoring unsupported node({astunparse.unparse(body)}) in ast.Module.")
|
||||
|
||||
|
||||
g_module_parser = reg_parser(ModuleParser())
|
|
@ -0,0 +1,40 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Parse ast.Return output-node of SymbolTree."""
|
||||
import ast
|
||||
|
||||
from ..symbol_tree import SymbolTree
|
||||
from ..node import Node
|
||||
from ..parser import Parser
|
||||
from ..parser_register import reg_parser
|
||||
|
||||
|
||||
class ReturnParser(Parser):
|
||||
"""Parse ast.Return output-node of SymbolTree."""
|
||||
|
||||
def target(self):
|
||||
"""Parse target type"""
|
||||
return ast.Return
|
||||
|
||||
def process(self, stree: SymbolTree, node: ast.Return):
|
||||
"""Parse ast.Return to output-node of SymbolTree."""
|
||||
return_value = node.value
|
||||
if not isinstance(return_value, ast.Name):
|
||||
raise RuntimeError("Only ast.Name as return value")
|
||||
node_return = Node.create_output_node(node, [return_value.id])
|
||||
stree.append_origin_field(node_return)
|
||||
|
||||
|
||||
g_return_parser = reg_parser(ReturnParser())
|
|
@ -0,0 +1,899 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""SymbolTree class define of Rewrite according to forward function of a network."""
|
||||
from typing import Optional, Union, Tuple
|
||||
import os
|
||||
import sys
|
||||
import ast
|
||||
import tempfile
|
||||
import astunparse
|
||||
|
||||
from mindspore.nn import Cell
|
||||
from mindspore import log as logger
|
||||
from .node import Node, TreeNode
|
||||
from .api.node_type import NodeType
|
||||
from .ast_modifier import AstModifier
|
||||
from .api.scoped_value import ScopedValue, ValueType
|
||||
from .symbol_tree_dumper import SymbolTreeDumper
|
||||
from .topological_manager import TopoManager
|
||||
from .namer import TargetNamer, NodeNamer
|
||||
|
||||
|
||||
class Position:
|
||||
"""Position indicates a source code position in one network."""
|
||||
|
||||
def __init__(self, symbol_tree, node, before_node: bool):
|
||||
"""
|
||||
Constructor of Position.
|
||||
Recommend to use class method of position rather than constructor of Position.
|
||||
|
||||
Args:
|
||||
symbol_tree (SymbolTree): A handler of SymbolTree indicated position in which SymbolTree.
|
||||
node (Node): A handler of Node indicated position is around which Node.
|
||||
before_node (bool): A bool indicated position is before or after the 'node'.
|
||||
"""
|
||||
self.symbol_tree = symbol_tree
|
||||
self.node = node
|
||||
self.before_node = before_node
|
||||
|
||||
@classmethod
|
||||
def create(cls, symbol_tree, node, before_node):
|
||||
"""
|
||||
Class method of Position. Return None when symbol_tree or node is None.
|
||||
|
||||
Args:
|
||||
symbol_tree: A handler of SymbolTree indicated position in which SymbolTree.
|
||||
node: A handler of Node indicated position is around which Node.
|
||||
before_node (bool): A bool indicated position is before or after the 'node'.
|
||||
|
||||
Returns:
|
||||
A Position.
|
||||
"""
|
||||
if symbol_tree is None or node is None:
|
||||
return None
|
||||
return Position(symbol_tree, node, before_node)
|
||||
|
||||
|
||||
class SymbolTree:
|
||||
"""A symbol-tree usually corresponding to forward method of a network."""
|
||||
|
||||
def __init__(self, origin_network: Cell, module_ast: ast.Module):
|
||||
"""
|
||||
Constructor of SymbolTree. Rewrite recommend using SymbolTreeBuilder to instantiate an instance of SymbolTree
|
||||
rather than invoking constructor of SymbolTree directly.
|
||||
|
||||
Args:
|
||||
origin_network (Cell): A handler to original network instance.
|
||||
module_ast (ast.Module): An instance of ast.AST represents ast node of original network.
|
||||
"""
|
||||
origin_network_key = "handler"
|
||||
# init unique-namers
|
||||
self._target_namer = TargetNamer()
|
||||
self._node_name_namer = NodeNamer()
|
||||
# name or node would use as name of field, so name of origin network handler field should be added into \
|
||||
# _node_name_namer.
|
||||
self._node_name_namer.add_name(origin_network_key)
|
||||
self._topo_mgr = TopoManager()
|
||||
|
||||
self._global_vars: {str, object} = {origin_network_key: origin_network}
|
||||
self._nodes: {str, Node} = {}
|
||||
# parameters of forward method
|
||||
self._inputs: [Node] = []
|
||||
self._ori_cls_name = type(origin_network).__name__
|
||||
self._opt_cls_name = self._ori_cls_name + "Opt"
|
||||
self._origin_network = origin_network
|
||||
self._module_ast: ast.Module = module_ast
|
||||
self._class_ast: Optional[ast.ClassDef] = None
|
||||
self._root_ast: Optional[ast.FunctionDef] = None
|
||||
self._init_func_ast: Optional[ast.FunctionDef] = None
|
||||
|
||||
# head node is always point to the first node(in source code order) of SymbolTree
|
||||
self._head = None
|
||||
# tail node is always point to the last node(in source code order) of SymbolTree
|
||||
self._tail = None
|
||||
self._return: Optional[Node] = None
|
||||
|
||||
def get_ori_cls_name(self) -> str:
|
||||
"""
|
||||
Get class name of original network.
|
||||
|
||||
Returns:
|
||||
A str represents class name of original network.
|
||||
"""
|
||||
return self._ori_cls_name
|
||||
|
||||
def get_opt_cls_name(self) -> str:
|
||||
"""
|
||||
Get class name of rewritten network.
|
||||
|
||||
Returns:
|
||||
A str represents class name of rewritten network.
|
||||
"""
|
||||
return self._opt_cls_name
|
||||
|
||||
def get_module_ast(self):
|
||||
"""
|
||||
Getter of `_module_ast`.
|
||||
|
||||
Returns:
|
||||
An instance of ast.AST represents ast node of corresponding module.
|
||||
"""
|
||||
return self._module_ast
|
||||
|
||||
def get_ast_root(self):
|
||||
"""
|
||||
Getter of `_root_ast`.
|
||||
|
||||
Returns:
|
||||
An instance of ast.AST represents ast node of corresponding forward method.
|
||||
"""
|
||||
return self._root_ast
|
||||
|
||||
def set_class_ast(self, ast_node: ast.ClassDef):
|
||||
"""
|
||||
Setter of `_class_ast`.
|
||||
|
||||
Args:
|
||||
ast_node (ast.ClassDef): An instance of ast.ClassDef represents ast node of corresponding network class.
|
||||
"""
|
||||
self._class_ast = ast_node
|
||||
|
||||
def set_init_func_ast(self, ast_node: ast.FunctionDef):
|
||||
"""
|
||||
Setter of `_init_func_ast`.
|
||||
|
||||
Args:
|
||||
ast_node (ast.FunctionDef): An instance of ast.FunctionDef represents ast node of init method of
|
||||
corresponding network class.
|
||||
"""
|
||||
self._init_func_ast = ast_node
|
||||
|
||||
def set_ast_root(self, ast_node: ast.FunctionDef):
|
||||
"""
|
||||
Setter of `_root_ast`.
|
||||
|
||||
Args:
|
||||
ast_node (ast.FunctionDef): An instance of ast.FunctionDef represents ast node of forward method of
|
||||
corresponding network class.
|
||||
"""
|
||||
self._root_ast = ast_node
|
||||
|
||||
def get_inputs(self):
|
||||
"""
|
||||
Getter of `_inputs` which represents parameters of current forward method.
|
||||
|
||||
Returns:
|
||||
A list of instance of Node whose node_type is NodeType.Input as input nodes.
|
||||
"""
|
||||
return self._inputs
|
||||
|
||||
def get_head_node(self):
|
||||
"""
|
||||
Getter of `_head` which represents the beginning node while iterating SymbolTree nodes.
|
||||
|
||||
Returns:
|
||||
An instance of node.
|
||||
"""
|
||||
return self._head
|
||||
|
||||
def get_return_node(self):
|
||||
"""
|
||||
Getter of `_return` which represents return statement of forward method of network.
|
||||
|
||||
Returns:
|
||||
An instance of node.
|
||||
"""
|
||||
return self._return
|
||||
|
||||
def get_origin_network(self):
|
||||
"""
|
||||
Getter of `_origin_network`.
|
||||
|
||||
Returns:
|
||||
An instance of Cell which represents original network.
|
||||
"""
|
||||
return self._origin_network
|
||||
|
||||
def nodes(self, unfold_subtree=True):
|
||||
"""
|
||||
Getter of nodes if current SymbolTree.
|
||||
|
||||
Args:
|
||||
unfold_subtree (bool): Need to iterate into sub-symbol-tree recursively.
|
||||
|
||||
Returns:
|
||||
A list of instance of Nodes.
|
||||
"""
|
||||
if unfold_subtree:
|
||||
nodes = []
|
||||
for _, v in self._nodes.items():
|
||||
if isinstance(v, TreeNode):
|
||||
nodes.extend(self.nodes(v.symbol_tree))
|
||||
else:
|
||||
nodes.append(v)
|
||||
return nodes
|
||||
return self._nodes.values()
|
||||
|
||||
def get_node(self, node_name: str) -> Optional[Node]:
|
||||
"""
|
||||
Get node of current symbol_tree by `node_name`.
|
||||
|
||||
Args:
|
||||
node_name (str): A str represents name of node as key of query.
|
||||
|
||||
Returns:
|
||||
An instance of Node if found else None.
|
||||
"""
|
||||
|
||||
return self._nodes.get(node_name)
|
||||
|
||||
def _get_real_node(self, node_or_name: Union[Node, str]) -> Optional[Node]:
|
||||
if isinstance(node_or_name, Node):
|
||||
return self.get_node(node_or_name.get_name())
|
||||
if isinstance(node_or_name, str):
|
||||
return self.get_node(node_or_name)
|
||||
return None
|
||||
|
||||
def get_node_inputs(self, node_or_name: Union[Node, str]) -> [Node]:
|
||||
"""
|
||||
Getter of inputs in topological relation of current 'node_or_name'.
|
||||
|
||||
Args:
|
||||
node_or_name (Union[Node, str]): An instance of node or a str represents name of node.
|
||||
|
||||
Returns:
|
||||
A list of instances of Node as input nodes if 'node_or_name' belong to current SymbolTree. An empty list if
|
||||
'node_or_name' not belong to current SymbolTree.
|
||||
"""
|
||||
|
||||
real_node: Optional[Node] = self._get_real_node(node_or_name)
|
||||
if real_node is None:
|
||||
logger.info("Node(%s) is not belong to current SymbolTree", node_or_name)
|
||||
return []
|
||||
return node_or_name.get_inputs()
|
||||
|
||||
def get_node_users(self, node_or_name: Union[Node, str]) -> [Tuple[Node, int]]:
|
||||
"""
|
||||
Getter of outputs in topological relation of current 'node_or_name'.
|
||||
|
||||
Args:
|
||||
node_or_name (Union[Node, str]): An instance of node or a str represents name of node.
|
||||
|
||||
Returns:
|
||||
A list of instances of Node as output nodes if 'node_or_name' belong to current SymbolTree. An empty list if
|
||||
'node_or_name' not belong to current SymbolTree.
|
||||
"""
|
||||
|
||||
real_node: Optional[Node] = self._get_real_node(node_or_name)
|
||||
if real_node is None:
|
||||
logger.info("Node(%s) is not belong to current SymbolTree", node_or_name)
|
||||
return []
|
||||
return self._topo_mgr.get_node_users(node_or_name)
|
||||
|
||||
def before(self, node_or_name: Union[Node, str]) -> Position:
|
||||
"""
|
||||
Get insert position before 'node_or_name' in source code list.
|
||||
Consider using symbol_tree, node and before/after as position for sub-tree feature.
|
||||
|
||||
Note:
|
||||
Topological order is not determined here which is determined by arguments of node and updated by
|
||||
TopologicalManager automatically.
|
||||
|
||||
Args:
|
||||
node_or_name (Union[Node, str]): An instance of node or a str represents name of node.
|
||||
|
||||
Returns:
|
||||
A Position represents an insert point.
|
||||
|
||||
Raises:
|
||||
AssertError: If 'node_or_name' is not a Node or a str
|
||||
RuntimeError: If 'node_or_name' is not belong to this SymbolTree or any sub-SymbolTree of current
|
||||
SymbolTree.
|
||||
"""
|
||||
|
||||
node = self._get_real_node(node_or_name)
|
||||
if node is None:
|
||||
raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
|
||||
return Position.create(node.get_belong_symbol_tree(), node, True)
|
||||
|
||||
def after(self, node_or_name: Union[Node, str]) -> Position:
|
||||
"""
|
||||
Get insert position after 'node_or_name' in source code list.
|
||||
Consider using symbol_tree, node and before/after as position for sub-tree feature.
|
||||
|
||||
Note:
|
||||
Topological order is not determined here which is determined by arguments of node and updated by
|
||||
TopologicalManager automatically.
|
||||
|
||||
Args:
|
||||
node_or_name (Union[Node, str]): An instance of node or a str represents name of node.
|
||||
|
||||
Returns:
|
||||
A Position represents an insert point.
|
||||
|
||||
Raises:
|
||||
AssertError: If 'node_or_name' is not a Node or a str
|
||||
RuntimeError: If 'node_or_name' is not belong to this SymbolTree or any sub-SymbolTree of current
|
||||
SymbolTree.
|
||||
"""
|
||||
|
||||
node = self._get_real_node(node_or_name)
|
||||
if node is None:
|
||||
raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
|
||||
return Position.create(node.get_belong_symbol_tree(), node, False)
|
||||
|
||||
def insert_node(self, position: Position, node: Node, insert_to_ast: bool = True) -> Node:
|
||||
"""
|
||||
Insert a node into SymbolTree.
|
||||
Note:
|
||||
Name of node will be unique while inserting node into SymbolTree.
|
||||
|
||||
ValueType.CustomObjValue type arguments will be converted to ValueType.NamingValue and custom object will
|
||||
be saved in global_vars dict while inserting node into SymbolTree.
|
||||
|
||||
Targets of node will be unique while inserting node into SymbolTree.
|
||||
|
||||
A field instantiation statement will be added into "init" function of network class using node name as field
|
||||
name when `insert_to_ast` is True while inserting node into SymbolTree.
|
||||
|
||||
An assign statement represents invoking to this node will be added into forward function of network class
|
||||
corresponding to field-instantiation-statement when `insert_to_ast` is True while inserting node into
|
||||
SymbolTree.
|
||||
|
||||
Topological relation is updated and inputs of corresponding node is updated.
|
||||
|
||||
Args:
|
||||
position (Position): A Position indicates an insert position point.
|
||||
node (Node): An instance of node to be inserted in.
|
||||
insert_to_ast (bool): A bool indicates whether to update corresponding ast node at same time, default is
|
||||
True.
|
||||
|
||||
Returns:
|
||||
An instance of node which has been inserted into SymbolTree.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If 'position' is not in current SymbolTree.
|
||||
RuntimeError: If corresponding ast node is not an ast.Assign when 'insert_to_ast' is True.
|
||||
"""
|
||||
|
||||
# if position in current SymbolTree
|
||||
if position is not None and position.symbol_tree is not self:
|
||||
raise RuntimeError("Position is not in current SymbolTree:", position)
|
||||
# unique targets, name while insert node into symbol_tree
|
||||
node_name = self._node_name_namer.get_name(node)
|
||||
node.set_name(node_name)
|
||||
self._handle_custom_obj_in_normalized_args(node)
|
||||
# _unique_targets must called after _update_args_for_unique and _update_kwargs_for_unique
|
||||
self._unique_targets(node)
|
||||
self._insert_node(position, node)
|
||||
# update init-function-ast and construct-function-ast
|
||||
if insert_to_ast:
|
||||
node.set_func(ScopedValue.create_naming_value(node_name, "self"))
|
||||
node_ast = node.get_ast()
|
||||
if not isinstance(node_ast, ast.Assign):
|
||||
raise RuntimeError("Only support insert cell op now")
|
||||
AstModifier.insert_assign_to_function(self._init_func_ast,
|
||||
targets=[ScopedValue(ValueType.NamingValue, "self", node_name)],
|
||||
expr=ScopedValue(ValueType.NamingValue, "global_vars", "get"),
|
||||
args=[ScopedValue(ValueType.StringValue, "", node_name)])
|
||||
AstModifier.insert_assign_ast_to_function(self._root_ast, node_ast,
|
||||
None if position is None else position.node.get_ast(),
|
||||
position.before_node)
|
||||
self._global_vars[node_name] = node.get_instance()
|
||||
return node
|
||||
|
||||
def append_node(self, node: Node, append_to_ast: bool = True) -> Node:
|
||||
"""
|
||||
Append a node to SymbolTree.
|
||||
|
||||
Args:
|
||||
node (Node): An instance of node to be appended.
|
||||
append_to_ast (bool): A bool indicates whether to update corresponding ast node at same time, default is
|
||||
True.
|
||||
|
||||
Returns:
|
||||
An instance of node which has been appended to SymbolTree.
|
||||
"""
|
||||
return self.insert_node(Position.create(self, self._tail, False), node, append_to_ast)
|
||||
|
||||
def append_origin_field(self, node: Node) -> Node:
|
||||
"""
|
||||
Append an original field node to SymbolTree. An original field node represents a node created from existing
|
||||
statement in forward method, from existing ast node in ast of forward method, so ast node do not need to update
|
||||
while these nodes appending to SymbolTree.
|
||||
This method is called while building SymbolTree usually.
|
||||
|
||||
Args:
|
||||
node (Node): An instance of node to be appended.
|
||||
|
||||
Returns:
|
||||
An instance of node which has been appended to SymbolTree.
|
||||
"""
|
||||
self._update_args_kwargs_for_unique(node)
|
||||
if node.get_node_type() == NodeType.Output:
|
||||
self._return = node
|
||||
elif node.get_node_type() == NodeType.Input:
|
||||
self._inputs.append(node)
|
||||
return self.append_node(node, False)
|
||||
|
||||
def append_input_node(self, param_name: str, default: Optional[ScopedValue] = None):
|
||||
"""
|
||||
Append an input node to SymbolTree corresponding to parameter of forward method of network class.
|
||||
This method is called while building SymbolTree usually.
|
||||
|
||||
Args:
|
||||
param_name (str): A str represents name of parameter of forward method of network class.
|
||||
default (Optional[ScopedValue] ): A ScopedValue represents default value of parameter. Default is None which
|
||||
means parameter has no default value.
|
||||
|
||||
Returns:
|
||||
An instance of input node which has been appended to SymbolTree.
|
||||
"""
|
||||
if param_name == "self":
|
||||
return
|
||||
for input_node in self._inputs:
|
||||
targets = input_node.get_targets()
|
||||
if len(targets) != 1:
|
||||
raise RuntimeError("targets should have 1 elements")
|
||||
target: ScopedValue = targets[0]
|
||||
if target.type != ValueType.NamingValue:
|
||||
raise RuntimeError("target.type should equal to ValueType.NamingValue")
|
||||
if target.scope != "":
|
||||
raise RuntimeError("target.scope should be empty")
|
||||
exist_param = target.value
|
||||
if exist_param == param_name:
|
||||
raise RuntimeError("input duplicated:", param_name)
|
||||
input_node = Node.create_input_node(None, param_name, default, name=f"input_{param_name}")
|
||||
self.append_origin_field(input_node)
|
||||
|
||||
def try_append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST) -> Optional[Node]:
|
||||
"""
|
||||
Try appending a python node to SymbolTree if 'ast_node' is not None and 'ast_node' is not Empty if 'ast_node' is
|
||||
a list or a dict.
|
||||
This method is called while building SymbolTree usually.
|
||||
|
||||
Args:
|
||||
ast_scope (ast.AST): A ast node represents ast node of scope of node.
|
||||
ast_node (ast.AST): A ast node represents ast node.
|
||||
|
||||
Returns:
|
||||
An instance of python node if a new node has been appended to SymbolTree else None.
|
||||
"""
|
||||
if ast_node is None:
|
||||
return None
|
||||
if isinstance(ast_node, (list, dict)) and not ast_node:
|
||||
return None
|
||||
return self.append_python_node(ast_scope, ast_node)
|
||||
|
||||
def append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST) -> Node:
|
||||
"""
|
||||
Append a python node to SymbolTree.
|
||||
This method is called while building SymbolTree usually.
|
||||
|
||||
Args:
|
||||
ast_scope (ast.AST): A ast node represents ast node of scope of node.
|
||||
ast_node (ast.AST): A ast node represents ast node.
|
||||
|
||||
Returns:
|
||||
An instance of python node which has been appended to SymbolTree.
|
||||
"""
|
||||
logger.info("Ignoring unsupported node(%s) in %s.", type(ast_node).__name__, type(ast_scope).__name__)
|
||||
node_name = self._node_name_namer.get_name(type(ast_node).__name__)
|
||||
node = Node.create_python_node(ast_node, node_name)
|
||||
self._insert_node(Position.create(self, self._tail, True), node)
|
||||
return node
|
||||
|
||||
def set_output(self, return_value: str, index: int) -> Node:
|
||||
"""
|
||||
Update return value of return of forward method of network class.
|
||||
|
||||
Args:
|
||||
return_value (str): A str represents new return value.
|
||||
index (int): A int indicates which return value to be updated.
|
||||
|
||||
Returns:
|
||||
An instance of node represents return node after updated.
|
||||
"""
|
||||
if self._return is None:
|
||||
raise RuntimeError("SymbolTree has no output")
|
||||
self.set_node_arg(self._return, index, return_value)
|
||||
return self._return
|
||||
|
||||
def erase_node(self, node_or_name: Union[Node, str]) -> Node:
|
||||
"""
|
||||
Erase a node from SymbolTree.
|
||||
Note:
|
||||
If node is depended on by other node, RuntimeError will raise.
|
||||
|
||||
Topological relation is updated.
|
||||
|
||||
Args:
|
||||
node_or_name (Union[Node, str]): An instance of node or a str represents name of node.
|
||||
|
||||
Returns:
|
||||
An instance of node which has been erased from SymbolTree.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If 'node_or_name' is not in current SymbolTree.
|
||||
RuntimeError: If erase corresponding ast node failed.
|
||||
"""
|
||||
|
||||
node = self._get_real_node(node_or_name)
|
||||
if node is None:
|
||||
raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
|
||||
ret = AstModifier.erase_ast_from_function(self._root_ast, node.get_ast())
|
||||
if not ret:
|
||||
raise RuntimeError("node not in function ast tree.")
|
||||
for key, value in self._nodes.items():
|
||||
if id(value) == id(node):
|
||||
self._nodes.pop(key)
|
||||
value.isolate()
|
||||
break
|
||||
self._topo_mgr.on_erase_node(node)
|
||||
return node
|
||||
|
||||
def _insert_tree(self, position: Position, root: Node, insert_to_ast: bool = True) -> Node:
|
||||
"""
|
||||
Insert a node-tree into SymbolTree.
|
||||
Note:
|
||||
Inputs of intra sub-tree nodes need to be welly set.
|
||||
|
||||
Inputs of inter sub-tree nodes will be updated by Rewrite automatically.
|
||||
|
||||
Args:
|
||||
position (Position): A Position indicates an insert position point.
|
||||
root (Node): An instance of node as root of node-tree to be inserted in.
|
||||
insert_to_ast (bool): A bool indicates whether to update corresponding ast node at same time, default is
|
||||
True.
|
||||
|
||||
Returns:
|
||||
An instance of node as root node of node-tree which has been inserted into SymbolTree.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If 'position' is not in current SymbolTree.
|
||||
"""
|
||||
|
||||
# if position not in current SymbolTree
|
||||
if position.symbol_tree is not self:
|
||||
raise RuntimeError("Position is not in current SymbolTree: ", position)
|
||||
|
||||
queue: [Node] = [root]
|
||||
todos: [] = []
|
||||
inputs_list: [] = []
|
||||
while queue:
|
||||
cur_node = queue.pop(0)
|
||||
if cur_node in todos:
|
||||
continue
|
||||
todos.append(cur_node)
|
||||
node_inputs = cur_node.get_inputs()
|
||||
inputs_list.append(node_inputs)
|
||||
for node_input in node_inputs:
|
||||
if node_input is not None:
|
||||
queue.append(node_input)
|
||||
todos.reverse()
|
||||
inputs_list.reverse()
|
||||
for index, todo in enumerate(todos):
|
||||
self.insert_node(position, todo, insert_to_ast)
|
||||
position = self.after(todo)
|
||||
# relink input of node
|
||||
original_inputs = inputs_list[index]
|
||||
for arg_idx, original_input in enumerate(original_inputs):
|
||||
if original_input is not None:
|
||||
self.set_node_arg_by_node(todo, arg_idx, original_input)
|
||||
return root
|
||||
|
||||
@staticmethod
|
||||
def _link_nodes_and_find_root(nodes: [Node]) -> Node:
|
||||
"""
|
||||
Find inputs for all nodes created by Replacement according to their targets and arguments.
|
||||
Find root node of all nodes created by Replacement. One and Only one root should be found.
|
||||
|
||||
Args:
|
||||
nodes ([Node]): A list of instance of Node created by Replacement.
|
||||
|
||||
Returns:
|
||||
An instance of Node represents root of input nodes.
|
||||
"""
|
||||
consumers: [ScopedValue] = []
|
||||
target_dict: {ScopedValue: Node} = {}
|
||||
for node in nodes:
|
||||
consumers.extend(node.get_args())
|
||||
for _, arg in node.get_kwargs():
|
||||
consumers.append(arg)
|
||||
for target in node.get_targets():
|
||||
if target_dict.get(target) is not None:
|
||||
raise RuntimeError("Target of node duplicated")
|
||||
target_dict[target] = node
|
||||
# find root node
|
||||
root = None
|
||||
for node in nodes:
|
||||
used = 0
|
||||
for target in node.get_targets():
|
||||
if target in consumers:
|
||||
used += 1
|
||||
if used == 0:
|
||||
if root is not None:
|
||||
raise RuntimeError("Replacement should only has one root")
|
||||
root = node
|
||||
if root is None:
|
||||
raise RuntimeError("No root node found in replacement nodes")
|
||||
# link node's input
|
||||
for node in nodes:
|
||||
inputs = []
|
||||
for _, arg in node.get_normalized_args().items():
|
||||
node_input: Node = target_dict.get(arg)
|
||||
if node_input is None:
|
||||
inputs.append(None)
|
||||
else:
|
||||
inputs.append(node_input)
|
||||
node.set_inputs(inputs)
|
||||
return root
|
||||
|
||||
def replace(self, old_node: Node, new_nodes: [Node]) -> Node:
|
||||
"""
|
||||
Replace an old_node with a node_tree. 'new_node' is the root node of the node_tree.
|
||||
Note:
|
||||
Rewrite will iterate all nodes linked to this root node and insert these nodes into symbol_tree.
|
||||
|
||||
Inputs of intra sub-tree nodes need to be welly set.
|
||||
|
||||
Inputs of inter sub-tree nodes will be updated by Rewrite automatically.
|
||||
|
||||
Args:
|
||||
old_node (Node): Node to be replaced.
|
||||
new_nodes ([Node]): Node tree to replace in.
|
||||
|
||||
Returns:
|
||||
An instance of Node represents root of node_tree been replaced in.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If 'old_node' is isolated.
|
||||
RuntimeError: If 'old_node' is not belong to current SymbolTree.
|
||||
"""
|
||||
|
||||
real_old_node = self._get_real_node(old_node)
|
||||
if real_old_node is None:
|
||||
raise RuntimeError("Old node is not belong to current SymbolTree:", old_node)
|
||||
# get position
|
||||
next_node: Node = old_node.get_next()
|
||||
prev_node: Node = old_node.get_prev()
|
||||
if prev_node is None and next_node is None:
|
||||
raise RuntimeError("Try replacing a isolated node: ", old_node)
|
||||
if next_node is None:
|
||||
position = self.after(prev_node)
|
||||
else:
|
||||
position = self.before(next_node)
|
||||
# insert node first, because targets of new_node is determined after insert
|
||||
new_tree_root = SymbolTree._link_nodes_and_find_root(new_nodes)
|
||||
new_node = self._insert_tree(position, new_tree_root)
|
||||
# use targets of insert tree to redirect edge
|
||||
users = self.get_node_users(old_node)
|
||||
if len(new_node.get_targets()) != 1:
|
||||
raise RuntimeError("targets of new_node should have 1 elements")
|
||||
for user in users:
|
||||
self.set_node_arg_by_node(user[0], user[1], new_node)
|
||||
# erase old_node after edge is redirected because node can be erased only when node is isolated topologically
|
||||
self.erase_node(old_node)
|
||||
return new_node
|
||||
|
||||
def set_node_arg(self, node: Union[Node, str], index: int, arg: Union[ScopedValue, str]):
|
||||
"""
|
||||
Set argument of 'node'.
|
||||
|
||||
Args:
|
||||
node (Union[Node, str]): Node to be modified. Can be a node or name of node.
|
||||
index (int): Indicate which input being modified.
|
||||
arg (Union[ScopedValue, str]): New argument to been set.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If 'node' is not belong to current SymbolTree.
|
||||
"""
|
||||
|
||||
real_node = self._get_real_node(node)
|
||||
if real_node is None:
|
||||
raise RuntimeError("Node is not belong to current SymbolTree: ", node)
|
||||
|
||||
new_arg, old_arg = node.set_arg(arg, index)
|
||||
self._topo_mgr.on_update_arg(node, index, old_arg, new_arg)
|
||||
|
||||
def set_node_arg_by_node(self, dst_node: Union[Node, str], arg_idx: int, src_node: Union[Node, str],
|
||||
out_idx: Optional[int] = None):
|
||||
"""
|
||||
Set argument of 'dst_node' by another Node.
|
||||
|
||||
Args:
|
||||
dst_node (Node): Node to be modified. Can be a node or name of node.
|
||||
arg_idx (int): Indicate which input being modified.
|
||||
src_node (Node): Node as new input. Can be a node or name of node.
|
||||
out_idx (Optional[int]): Indicate which output of 'src_node' as new input of 'dst_node'. Default is None
|
||||
which means use first output of 'node_to_link' as new input.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If 'dst_node' is not belong to current SymbolTree.
|
||||
RuntimeError: If 'src_node' is not belong to current SymbolTree.
|
||||
RuntimeError: If 'out_idx' is out of range.
|
||||
RuntimeError: If 'src_node' has multi-outputs while 'out_idx' is None or 'out_idx' is not offered.
|
||||
"""
|
||||
|
||||
real_dst_node = self._get_real_node(dst_node)
|
||||
if real_dst_node is None:
|
||||
raise RuntimeError("dst_node is not belong to current SymbolTree: ", dst_node)
|
||||
real_src_node = self._get_real_node(src_node)
|
||||
if real_src_node is None:
|
||||
raise RuntimeError("src_node is not belong to current SymbolTree: ", src_node)
|
||||
|
||||
targets = real_src_node.get_targets()
|
||||
if out_idx is None:
|
||||
if len(targets) != 1:
|
||||
raise RuntimeError("node should has one output when out_idx is not provided")
|
||||
out_idx = 0
|
||||
if out_idx >= len(targets):
|
||||
raise RuntimeError("out_idx out of range: ", out_idx)
|
||||
new_arg = targets[out_idx]
|
||||
self.set_node_arg(real_dst_node, arg_idx, new_arg)
|
||||
|
||||
def dump(self):
|
||||
"""Dump graph."""
|
||||
dump_st = SymbolTreeDumper(self)
|
||||
dump_st.dump()
|
||||
|
||||
def get_code(self) -> str:
|
||||
"""
|
||||
Get source code of modified network.
|
||||
|
||||
Returns:
|
||||
A str represents source code of modified network.
|
||||
"""
|
||||
ast.fix_missing_locations(self._module_ast)
|
||||
return astunparse.unparse(self._module_ast)
|
||||
|
||||
def get_network(self):
|
||||
"""
|
||||
Get modified network.
|
||||
|
||||
Returns:
|
||||
A network object.
|
||||
"""
|
||||
cls = self._get_cls_through_file()
|
||||
return cls(self._global_vars)
|
||||
|
||||
def _unique_targets(self, node: Node):
|
||||
"""
|
||||
Unique targets of node by _target_namer.
|
||||
|
||||
Args:
|
||||
node (Node): A Node whose targets to be uniqued.
|
||||
"""
|
||||
new_targets: [ScopedValue] = []
|
||||
if node.get_targets() is None:
|
||||
return
|
||||
for target in node.get_targets():
|
||||
if not isinstance(target, ScopedValue):
|
||||
raise TypeError("target should be ScopedValue, got: ", type(target))
|
||||
unique_target = self._target_namer.get_name(target.value)
|
||||
new_targets.append(ScopedValue.create_naming_value(unique_target, target.scope))
|
||||
node.set_targets(new_targets)
|
||||
|
||||
def _update_args_kwargs_for_unique(self, node: Node):
|
||||
"""
|
||||
Update arguments and keyword arguments of node because unique-ing of targets of other nodes.
|
||||
|
||||
Args:
|
||||
node (Node): A Node whose arguments and keyword arguments to be updated.
|
||||
"""
|
||||
result: {str: ScopedValue} = {}
|
||||
if node.get_normalized_args() is None:
|
||||
return
|
||||
for key, arg in node.get_normalized_args().items():
|
||||
if not isinstance(arg, ScopedValue):
|
||||
raise TypeError("arg should be ScopedValue, got: ", type(arg))
|
||||
if arg.type == ValueType.NamingValue:
|
||||
# unique name
|
||||
new_arg = ScopedValue(arg.type, arg.scope, self._target_namer.get_real_arg(arg.value))
|
||||
result[key] = new_arg
|
||||
else:
|
||||
result[key] = arg
|
||||
node.set_normalized_args(result)
|
||||
|
||||
def _add_node2nodes(self, node: Node):
|
||||
"""
|
||||
Add `node` to `_nodes` dict.
|
||||
|
||||
Args:
|
||||
node (Node): A Node to be added into `_nodes`.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If name of the node is duplicated.
|
||||
"""
|
||||
node_name = node.get_name()
|
||||
if self._nodes.get(node_name) is not None:
|
||||
raise RuntimeError("generated duplicated node name", node_name, self._nodes.get(node_name),
|
||||
node)
|
||||
self._nodes[node_name] = node
|
||||
|
||||
def _insert_node(self, position: Optional[Position], node: Node):
|
||||
"""
|
||||
Insert a node into SymbolTree.
|
||||
1. Add `node` to `_nodes`.
|
||||
2. Insert `node` to node list(source code order).
|
||||
3. Update topological relation and update inputs of `node`.
|
||||
|
||||
Args:
|
||||
position (Optional[Position]): Indicates node insert position. Position is None when inserting first node of
|
||||
SymbolTree.
|
||||
node (Node): A Node to be inserted into SymbolTree.
|
||||
|
||||
Raises:
|
||||
RuntimeError: Position is None when _nodes of SymbolTree is not Empty. It means position can not be None
|
||||
unless inserting first node.
|
||||
"""
|
||||
if position is None:
|
||||
if self._nodes:
|
||||
raise RuntimeError("self._nodes should be empty")
|
||||
self._head = node
|
||||
else:
|
||||
if position.before_node:
|
||||
position.node.insert_before(node)
|
||||
else:
|
||||
position.node.insert_after(node)
|
||||
self._tail = node
|
||||
self._add_node2nodes(node)
|
||||
self._topo_mgr.on_insert_node(node)
|
||||
node.set_belong_symbol_tree(self)
|
||||
|
||||
def _handle_custom_obj_in_normalized_args(self, node: Node):
|
||||
"""
|
||||
Convert CustomObjValue type argument to NamingValue type argument by storing custom object in global_vars dict.
|
||||
|
||||
Args:
|
||||
node (Node): A Node whose arguments and keyword arguments to be handled.
|
||||
"""
|
||||
result: {str, ScopedValue} = {}
|
||||
for arg, value in node.get_normalized_args().items():
|
||||
if not isinstance(value, ScopedValue):
|
||||
raise TypeError("value should be ScopedValue, got: ", type(value))
|
||||
if value.type == ValueType.CustomObjValue:
|
||||
field = self._node_name_namer.get_name(f"var_{type(value.value).__name__}")
|
||||
self._global_vars[field] = value.value
|
||||
init_targets = [ScopedValue.create_naming_value(field, "self")]
|
||||
AstModifier.append_global_vars_expr_to_init(self._init_func_ast, init_targets, field)
|
||||
result[arg] = init_targets[0]
|
||||
else:
|
||||
result[arg] = value
|
||||
node.set_normalized_args(result)
|
||||
|
||||
def _get_cls_through_file(self):
|
||||
"""
|
||||
Load rewritten network class of current SymbolTree.
|
||||
1. Get source code of current SymbolTree.
|
||||
2. Saving source code to a tempfile.
|
||||
3. Import rewritten network class using "__import__" function.
|
||||
|
||||
Returns:
|
||||
A class handle.
|
||||
"""
|
||||
source = self.get_code()
|
||||
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.py')
|
||||
tmp_file.write(source.encode('utf8'))
|
||||
tmp_file.flush()
|
||||
tmp_file_name = tmp_file.name
|
||||
tmp_module_path, tmp_module_file = os.path.split(tmp_file_name)
|
||||
tmp_module_name = tmp_module_file[:-3]
|
||||
sys.path.append(tmp_module_path)
|
||||
tmp_module = __import__(tmp_module_name)
|
||||
network_cls = getattr(tmp_module, self._opt_cls_name)
|
||||
if network_cls is None:
|
||||
raise RuntimeError("Can not find network class:", self._opt_cls_name)
|
||||
return network_cls
|
|
@ -0,0 +1,73 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""SymbolTree builder."""
|
||||
from typing import Optional
|
||||
import ast
|
||||
import inspect
|
||||
|
||||
from mindspore.nn import Cell
|
||||
from .symbol_tree import SymbolTree
|
||||
from .parser_register import ParserRegister
|
||||
from .parser import Parser
|
||||
from .ast_transformers import FlattenRecursiveStmt
|
||||
|
||||
|
||||
class SymbolTreeBuilder:
|
||||
"""SymbolTree builder."""
|
||||
|
||||
def __init__(self, network: Cell):
|
||||
"""
|
||||
Constructor of SymbolTreeBuilder.
|
||||
|
||||
Args:
|
||||
network (Cell): An instance of Cell represents a network from which SymbolTree will be built.
|
||||
"""
|
||||
if not isinstance(network, Cell):
|
||||
raise RuntimeError("Only support network with Cell type now")
|
||||
self._origin_net = network
|
||||
network_str = inspect.getsource(type(network))
|
||||
self._ast_root: ast.Module = ast.parse(network_str)
|
||||
self._root_tree: Optional[SymbolTree] = None
|
||||
|
||||
@staticmethod
|
||||
def _ast_transform(ast_root: ast.AST) -> ast.AST:
|
||||
"""
|
||||
Optimize ast before parse.
|
||||
|
||||
Args:
|
||||
ast_root (ast.AST): An instance of ast to be optimized.
|
||||
|
||||
Returns:
|
||||
An instance of ast been optimized.
|
||||
"""
|
||||
transform_list = [FlattenRecursiveStmt()]
|
||||
for transformer in transform_list:
|
||||
ast_root = transformer.transform(ast_root)
|
||||
return ast_root
|
||||
|
||||
def build(self) -> SymbolTree:
|
||||
"""
|
||||
Build SymbolTree.
|
||||
|
||||
Returns:
|
||||
An instance of SymbolTree.
|
||||
"""
|
||||
self._ast_root = SymbolTreeBuilder._ast_transform(self._ast_root)
|
||||
if not isinstance(self._ast_root, ast.Module):
|
||||
raise RuntimeError("ast_root should be a ast.Module")
|
||||
self._root_tree: SymbolTree = SymbolTree(self._origin_net, self._ast_root)
|
||||
parser: Parser = ParserRegister.instance().get_parser(ast.Module)
|
||||
parser.process(self._root_tree, self._ast_root)
|
||||
return self._root_tree
|
|
@ -0,0 +1,144 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""SymbolTree dumper."""
|
||||
import inspect
|
||||
|
||||
from mindspore import log as logger
|
||||
from .node import Node
|
||||
from .api.node_type import NodeType
|
||||
from .api.scoped_value import ScopedValue, ValueType
|
||||
|
||||
|
||||
class SymbolTreeDumper:
|
||||
"""SymbolTree dumper."""
|
||||
|
||||
def __init__(self, symbol_tree):
|
||||
"""
|
||||
Constructor of SymbolTreeDumper.
|
||||
|
||||
Args:
|
||||
symbol_tree (SymbolTree): An instance of SymbolTree to be dumped.
|
||||
"""
|
||||
self._symbol_tree = symbol_tree
|
||||
self._dump_buffer = ""
|
||||
self._dump_key2index = {}
|
||||
|
||||
def _reset(self):
|
||||
"""Reset SymbolTreeDumper."""
|
||||
self._dump_buffer = ""
|
||||
self._dump_key2index = {}
|
||||
|
||||
def _dump_global_info(self):
|
||||
"""Dump global info of SymbolTree."""
|
||||
self._dump_buffer += f"#SymbolTree entry : @construct \n"
|
||||
|
||||
def _dump_inputs(self):
|
||||
"""Dump inputs of SymbolTree."""
|
||||
inputs = self._symbol_tree.get_inputs()
|
||||
self._dump_buffer += f"#Inputs num : {len(inputs)}\n"
|
||||
for single_input in inputs:
|
||||
targets = single_input.get_targets()
|
||||
if len(targets) != 1:
|
||||
raise RuntimeError("Only support one output per node now")
|
||||
target: ScopedValue = targets[0]
|
||||
if target.type != ValueType.NamingValue:
|
||||
raise RuntimeError("target.type should equal to ValueType.NamingValue")
|
||||
if target.scope != "":
|
||||
raise RuntimeError("target.scope should be empty")
|
||||
input_arg = target.value
|
||||
input_name = f"%input_{input_arg}"
|
||||
if input_arg in self._dump_key2index.keys():
|
||||
raise RuntimeError("input_arg duplicated: ", input_arg)
|
||||
self._dump_key2index[input_arg] = input_name
|
||||
self._dump_buffer += f"{input_name}\n"
|
||||
self._dump_buffer += f"\n"
|
||||
|
||||
def _dump_nodes(self):
|
||||
"""Dump nodes of SymbolTree."""
|
||||
self._dump_buffer += f"Symbol Tree @construct {{ \n"
|
||||
node_no = -1
|
||||
|
||||
node: Node = self._symbol_tree.get_head_node().get_next()
|
||||
while node is not None:
|
||||
if node.get_node_type() is NodeType.Output:
|
||||
self._dump_buffer += f" Return(%{node_no}) \n"
|
||||
self._dump_buffer += f" : (null) \n"
|
||||
self._dump_buffer += f" # In file {inspect.getfile(type(self._symbol_tree.get_origin_network()))}"
|
||||
|
||||
node = node.get_next()
|
||||
continue
|
||||
|
||||
node_no += 1
|
||||
self._dump_key2index[node.get_name()] = f"%{node_no}"
|
||||
|
||||
targets = node.get_targets()
|
||||
if not targets:
|
||||
targets = [None]
|
||||
op_type = node.get_instance_type()
|
||||
if hasattr(op_type, "__name__"):
|
||||
op_type_name = op_type.__name__
|
||||
else:
|
||||
if hasattr(type(op_type), "__name__"):
|
||||
op_type_name = type(op_type).__name__
|
||||
else:
|
||||
raise RuntimeError("op has no attr __name__")
|
||||
self._dump_buffer += f" %{node_no}({targets[0]}) = {op_type_name}"
|
||||
|
||||
args = node.get_normalized_args().values()
|
||||
if args:
|
||||
arg_str = f""
|
||||
for arg in args:
|
||||
if isinstance(arg, str):
|
||||
arg_name = arg
|
||||
elif isinstance(arg, ScopedValue):
|
||||
arg_name = arg.value
|
||||
else:
|
||||
raise RuntimeError(f"arg type {type(arg)} of {arg} is not supported now")
|
||||
|
||||
if arg_name in self._dump_key2index.keys():
|
||||
arg_str += f"{self._dump_key2index[arg_name]}, "
|
||||
else:
|
||||
logger.warning("arg not appears before")
|
||||
arg_str += f"{arg_name}, "
|
||||
self._dump_buffer += f"({arg_str[:-2]})"
|
||||
|
||||
self._dump_buffer += f"{{instance name: {node.get_name()}}}"
|
||||
|
||||
self._dump_buffer += f" attributes {{"
|
||||
attrs = node.get_attributes()
|
||||
if attrs:
|
||||
attrs_str = f""
|
||||
for attr in attrs:
|
||||
if not isinstance(attr, str):
|
||||
raise TypeError("attr should be str, got: ", type(attr))
|
||||
attrs_str += f"{attr}: {attrs[attr]}, "
|
||||
self._dump_buffer += attrs_str[:-2]
|
||||
self._dump_buffer += f"}}\n"
|
||||
|
||||
self._dump_buffer += f" : (null) -> (null)\n"
|
||||
cls_real_path = inspect.getfile(node.get_instance_type()) if node.get_instance() else None
|
||||
self._dump_buffer += f" # In file {cls_real_path}\n"
|
||||
self._dump_buffer += f" # In file {inspect.getfile(type(self._symbol_tree.get_origin_network()))}\n"
|
||||
|
||||
node = node.get_next()
|
||||
self._dump_buffer += f"}}\n"
|
||||
|
||||
def dump(self):
|
||||
"""Dump SymbolTree."""
|
||||
self._reset()
|
||||
self._dump_global_info()
|
||||
self._dump_inputs()
|
||||
self._dump_nodes()
|
||||
print(self._dump_buffer)
|
|
@ -0,0 +1,189 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""SymbolTree topological-relationship manager."""
|
||||
from typing import Tuple
|
||||
|
||||
from .api.scoped_value import ScopedValue
|
||||
from .node import Node
|
||||
|
||||
|
||||
class TopoManager:
|
||||
"""SymbolTree topological-relationship manager."""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Constructor of TopoManager.
|
||||
Init provider and consumer.
|
||||
Key of dict is an instance of ScopedValue.
|
||||
Value of dict is a tuple whose first is an instance of Node, whose second is an index.
|
||||
It means node's index th arg is argument
|
||||
"""
|
||||
self._target_provider: {ScopedValue, (Node, int)} = {}
|
||||
self._target_consumer: {ScopedValue, [(Node, int)]} = {}
|
||||
|
||||
def get_node_users(self, node: Node) -> [Tuple[Node, int]]:
|
||||
"""
|
||||
Get all nodes which depend on node corresponding to node_or_name.
|
||||
|
||||
Args:
|
||||
node (Node): An instance of node.
|
||||
|
||||
Returns:
|
||||
A list of nodes represents node users.
|
||||
"""
|
||||
targets = node.get_targets()
|
||||
results = []
|
||||
for target in targets:
|
||||
consumers = self._target_consumer.get(target)
|
||||
if consumers is None:
|
||||
continue
|
||||
results.extend(consumers)
|
||||
unique_results = []
|
||||
for result in results:
|
||||
if result not in unique_results:
|
||||
unique_results.append(result)
|
||||
return unique_results
|
||||
|
||||
def _add_consumer(self, product: ScopedValue, consumer: Node, index):
|
||||
"""
|
||||
Add a consumer to consumer dict.
|
||||
|
||||
Args:
|
||||
product (ScopedValue): An instance of ScopedValue represents product to be consumed.
|
||||
consumer (Node): An instance of Node represents consumer.
|
||||
index (int): A int represents which input of consumer is the product.
|
||||
"""
|
||||
consumers = self._target_consumer.get(product)
|
||||
if consumers is None:
|
||||
self._target_consumer[product] = [(consumer, index)]
|
||||
else:
|
||||
self._target_consumer[product].append((consumer, index))
|
||||
|
||||
def _erase_provider(self, product: ScopedValue):
|
||||
"""
|
||||
Erase a provider from provider dict.
|
||||
|
||||
Args:
|
||||
product (ScopedValue): An instance of ScopedValue represents product to be erased.
|
||||
"""
|
||||
if self._target_provider.get(product) is not None:
|
||||
self._target_provider.pop(product)
|
||||
|
||||
def _erase_consumer(self, product: ScopedValue, consumer: Node):
|
||||
"""
|
||||
Erase a consumer from consumer dict.
|
||||
|
||||
Args:
|
||||
product (ScopedValue): An instance of ScopedValue represents product whose consumer would be erased.
|
||||
consumer (Node): An instance of Node which would be erased as a consumer.
|
||||
"""
|
||||
consumers = self._target_consumer.get(product)
|
||||
if consumers is None:
|
||||
return
|
||||
for i in range(len(consumers) - 1, -1, -1):
|
||||
exist_ele = consumers[i]
|
||||
if id(exist_ele[0]) == id(consumer):
|
||||
consumers.pop(i)
|
||||
|
||||
def _update_node_inputs(self, node: Node) -> [Node]:
|
||||
"""
|
||||
Update inputs of node by current provider dict and consumer dict.
|
||||
|
||||
Args:
|
||||
node (Node): An instance of Node whose inputs will be updated.
|
||||
|
||||
Returns:
|
||||
A list of instance of nodes represents inputs of node.
|
||||
"""
|
||||
if node.get_normalized_args() is None:
|
||||
node.set_inputs([])
|
||||
return []
|
||||
inputs = []
|
||||
for arg in node.get_normalized_args().values():
|
||||
provider = self._target_provider.get(arg)
|
||||
# some arg of some node may be self.xxx which is not an output of another node
|
||||
if provider is not None:
|
||||
inputs.append(provider[0])
|
||||
node.set_inputs(inputs)
|
||||
return inputs
|
||||
|
||||
def on_insert_node(self, node: Node):
|
||||
"""
|
||||
Update provider dict and consumer dict while inserting node into SymbolTree and update inputs of node by updated
|
||||
provider dict and consumer dict.
|
||||
|
||||
Args:
|
||||
node (Node): An instance of Node which been inserted into SymbolTree.
|
||||
"""
|
||||
if node.get_targets() is not None:
|
||||
for i in range(0, len(node.get_targets())):
|
||||
target = node.get_targets()[i]
|
||||
if self._target_provider.get(target) is not None:
|
||||
raise RuntimeError("target duplicated:", target)
|
||||
self._target_provider[target] = (node, i)
|
||||
if node.get_normalized_args() is not None:
|
||||
for index, arg in enumerate(node.get_normalized_args().values()):
|
||||
self._add_consumer(arg, node, index)
|
||||
self._update_node_inputs(node)
|
||||
|
||||
def on_erase_node(self, node: Node):
|
||||
"""
|
||||
Update provider dict and consumer dict while erasing node from SymbolTree.
|
||||
|
||||
Args:
|
||||
node (Node): An instance of Node which been erased from SymbolTree.
|
||||
"""
|
||||
if node.get_targets() is not None:
|
||||
for target in node.get_targets():
|
||||
consumers = self._target_consumer.get(target)
|
||||
if consumers is not None and consumers:
|
||||
raise RuntimeError("Only support erase isolated node: ", node.get_name(), target)
|
||||
self._erase_provider(target)
|
||||
if node.get_normalized_args() is not None:
|
||||
for arg in node.get_normalized_args().values():
|
||||
self._erase_consumer(arg, node)
|
||||
# clear inputs of node rather than call _update_node_inputs because node is already erase from consumer dict
|
||||
node.set_inputs([])
|
||||
|
||||
def on_update_arg(self, node: Node, arg_idx: int, old_arg: ScopedValue, new_arg: ScopedValue):
|
||||
"""
|
||||
Update provider dict and consumer dict while updating argument of node and update inputs of node by updated
|
||||
provider dict and consumer dict.
|
||||
|
||||
Args:
|
||||
node (Node): An instance of Node whose arguments being updated.
|
||||
arg_idx (int): An int indicates which argument of node being updated.
|
||||
old_arg (ScopedValue): An instance of ScopedValue represents original argument.
|
||||
new_arg (ScopedValue): An instance of ScopedValue represents new argument.
|
||||
"""
|
||||
self._erase_consumer(old_arg, node)
|
||||
self._add_consumer(new_arg, node, arg_idx)
|
||||
self._update_node_inputs(node)
|
||||
|
||||
def dump(self, title=""):
|
||||
"""
|
||||
Dump topological relation.
|
||||
|
||||
Args:
|
||||
title (str): A string as a title will be printed before dumping topological relation.
|
||||
"""
|
||||
print(f"{title}------------------------------------------------------------------------------------")
|
||||
for k, v in self._target_provider.items():
|
||||
print(f"{v[0].get_name()} produces {k.value}")
|
||||
for k, v in self._target_consumer.items():
|
||||
print(f"{k.value} is consumed by: ")
|
||||
for ele in v:
|
||||
print(ele[0].get_name())
|
||||
print(f"-----------------------------------------------------------------------------------------")
|
|
@ -12,3 +12,5 @@ pycocotools >= 2.0.2 # for st test
|
|||
tables >= 3.6.1 # for st test
|
||||
easydict >= 1.9 # for st test
|
||||
psutil >= 5.7.0
|
||||
astunparse >= 0.0
|
||||
astpretty >= 0.0
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import ast
|
||||
import inspect
|
||||
|
||||
from mindspore.nn import Cell, Conv2d, BatchNorm2d, ReLU
|
||||
from mindspore.rewrite.ast_transformers.flatten_recursive_stmt import FlattenRecursiveStmt
|
||||
|
||||
|
||||
class Network(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = Conv2d(16, 16, 3)
|
||||
self.bn = BatchNorm2d(16)
|
||||
self.relu1 = ReLU()
|
||||
self.relu2 = ReLU()
|
||||
self.relu3 = ReLU()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv(x + 1)
|
||||
x = x + 1 * 5 + 4 / 2 + self.bn(x)
|
||||
self.relu1(x * 5)
|
||||
x = self.relu2(x + 1)
|
||||
x = True and x or x
|
||||
x = self.relu3(x)
|
||||
return x + 3
|
||||
|
||||
|
||||
def _get_ast():
|
||||
source = inspect.getsource(Network)
|
||||
return ast.parse(source)
|
||||
|
||||
|
||||
def test_flatten():
|
||||
"""
|
||||
Feature: Class FlattenRecursiveStmt.
|
||||
Description: Apply FlattenRecursiveStmt on a simple network.
|
||||
Expectation: Success.
|
||||
"""
|
||||
ast_node = _get_ast()
|
||||
frs = FlattenRecursiveStmt()
|
||||
frs.transform(ast_node)
|
||||
assert len(ast_node.body) == 1
|
||||
ast_class = ast_node.body[0]
|
||||
assert isinstance(ast_class, ast.ClassDef)
|
||||
assert len(ast_class.body) == 2
|
||||
ast_init_func = ast_class.body[0]
|
||||
assert isinstance(ast_init_func, ast.FunctionDef)
|
||||
assert len(ast_init_func.body) == 6
|
||||
ast_construct_func = ast_class.body[1]
|
||||
assert isinstance(ast_construct_func, ast.FunctionDef)
|
||||
assert len(ast_construct_func.body) == 17
|
|
@ -0,0 +1,141 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
|
||||
import mindspore
|
||||
from mindspore import Tensor, nn
|
||||
from mindspore.rewrite import SymbolTree, ScopedValue, ValueType, Node
|
||||
from mindspore.common.initializer import Normal
|
||||
from mindspore.common.api import _cell_graph_executor
|
||||
|
||||
|
||||
class SimpleNet(nn.Cell):
|
||||
def __init__(self, num_class=10, num_channel=1):
|
||||
super(SimpleNet, self).__init__()
|
||||
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
|
||||
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
|
||||
self.relu = nn.ReLU()
|
||||
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.flatten = nn.Flatten()
|
||||
self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
|
||||
self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
|
||||
self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
|
||||
self.var = 10
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv1(x)
|
||||
x = x
|
||||
y = self.var
|
||||
y = y * 5
|
||||
y = y and True
|
||||
x = self.relu(x)
|
||||
x = self.max_pool2d(x)
|
||||
x = self.conv2(x)
|
||||
x = self.relu(x)
|
||||
x = self.max_pool2d(x)
|
||||
x = self.flatten(x)
|
||||
x = self.relu(self.fc1(x))
|
||||
x = self.relu(self.fc2(x))
|
||||
x = self.fc3(x)
|
||||
return x
|
||||
|
||||
|
||||
class MyCell(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = nn.Dense(5, 16)
|
||||
|
||||
def construct(self, x, y):
|
||||
x = self.conv(x)
|
||||
x = mindspore.ops.Add()(x, y)
|
||||
return x
|
||||
|
||||
|
||||
def add_conv_before_flatten(stree: SymbolTree):
|
||||
new_conv_node = None
|
||||
for node in stree.nodes():
|
||||
if node.get_instance_type() == mindspore.nn.Flatten:
|
||||
position = stree.before(node)
|
||||
new_conv = nn.Conv2d(16, 16, 3)
|
||||
new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv',
|
||||
args=[ScopedValue.create_naming_value('self_max_po')])
|
||||
stree.insert(position, new_conv_node)
|
||||
break
|
||||
if new_conv_node is not None:
|
||||
for node in stree.nodes():
|
||||
if node.get_instance_type() == mindspore.nn.Flatten:
|
||||
inputs = node.get_inputs()
|
||||
assert len(inputs) == 1
|
||||
new_conv_node.set_arg_by_node(0, inputs[0])
|
||||
|
||||
|
||||
def add_my_cell_after_x_12(stree: SymbolTree):
|
||||
for node in stree.nodes():
|
||||
targets = node.get_targets()
|
||||
if targets is None:
|
||||
continue
|
||||
assert targets[0].type == ValueType.NamingValue
|
||||
target = str(targets[0])
|
||||
if target == "x_12":
|
||||
position = stree.after(node)
|
||||
custom_cell = MyCell()
|
||||
bias = Tensor(1, mindspore.int32)
|
||||
new_custom_node = Node.create_call_cell(custom_cell, targets=['nx2'],
|
||||
args=[ScopedValue.create_naming_value('nx3'),
|
||||
ScopedValue.create_variable_value(bias)], name='my_cell')
|
||||
stree.insert(position, new_custom_node)
|
||||
new_custom_node.set_arg(0, "x_12")
|
||||
break
|
||||
|
||||
|
||||
def erase_node_x_11(stree: SymbolTree):
|
||||
return_node = None
|
||||
for node in stree.nodes():
|
||||
if node.get_targets() is None:
|
||||
return_node = node
|
||||
break
|
||||
assert return_node is not None
|
||||
for node in stree.nodes():
|
||||
targets = node.get_targets()
|
||||
if targets is None:
|
||||
continue
|
||||
assert targets[0].type == ValueType.NamingValue
|
||||
target = str(targets[0])
|
||||
if target == "x_11":
|
||||
stree.set_output(0, "x_10")
|
||||
stree.erase_node(node)
|
||||
break
|
||||
|
||||
|
||||
def transform(stree: SymbolTree):
|
||||
add_conv_before_flatten(stree)
|
||||
add_my_cell_after_x_12(stree)
|
||||
erase_node_x_11(stree)
|
||||
|
||||
|
||||
def test_simple_net():
|
||||
"""
|
||||
Feature: Module rewrite.
|
||||
Description: Resolve a simple network by rewrite and do some transform on it.
|
||||
Expectation: Result of rewrite can be compiled.
|
||||
"""
|
||||
net = SimpleNet(10)
|
||||
stree = SymbolTree(net)
|
||||
transform(stree)
|
||||
print("------------------------------------ keys of global_vars: ",
|
||||
getattr(stree.get_handler(), "_global_vars").keys())
|
||||
net_opt = stree.get_network()
|
||||
data_in = Tensor(np.ones([1, 1, 32, 32]), mindspore.float32)
|
||||
_cell_graph_executor.compile(net_opt, data_in)
|
|
@ -0,0 +1,126 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import ast
|
||||
|
||||
from mindspore.nn import Cell
|
||||
from mindspore.rewrite import ScopedValue
|
||||
from mindspore.rewrite.node import Node
|
||||
|
||||
|
||||
class FakeCell(Cell):
|
||||
def construct(self, input1, input2, cool_boy=None):
|
||||
return input1 + input2 + cool_boy
|
||||
|
||||
|
||||
class FakeCell2(Cell):
|
||||
def construct(self, a, b, d, e, *args, f=6, **kwargs):
|
||||
return a + b + d + e + sum(args) + f + sum(kwargs.values())
|
||||
|
||||
|
||||
class FakeCell3(Cell):
|
||||
def construct(self, a, b, *args, f=6, h=7, **kwargs):
|
||||
return a + b + f + h + sum(args) + sum(kwargs.values())
|
||||
|
||||
|
||||
def test_create_by_cell():
|
||||
"""
|
||||
Feature: Python api create_call_cell of Node of Rewrite.
|
||||
Description: Call create_call_cell to create a node.
|
||||
Expectation: Success.
|
||||
"""
|
||||
node = Node.create_call_cell(FakeCell(), None, ['x'], 'new_conv',
|
||||
[ScopedValue.create_naming_value('x'), ScopedValue.create_variable_value(1)],
|
||||
{"cool_boy": ScopedValue.create_naming_value('Naroto')}, 'new_conv')
|
||||
assert node._args_num == 2
|
||||
assert node._kwargs_num == 1
|
||||
assert node._normalized_args_keys == ["input1", "input2", "cool_boy"]
|
||||
|
||||
assert node._normalized_args == {
|
||||
"input1": ScopedValue.create_naming_value('x'),
|
||||
"input2": ScopedValue.create_variable_value(1),
|
||||
"cool_boy": ScopedValue.create_naming_value('Naroto')
|
||||
}
|
||||
|
||||
ast_node: ast.Assign = node.get_ast()
|
||||
assign_value: ast.Call = ast_node.value
|
||||
args_ast = assign_value.args
|
||||
keywords_ast = assign_value.keywords
|
||||
assert len(args_ast) == 2
|
||||
assert len(keywords_ast) == 1
|
||||
assert keywords_ast[0].arg == "cool_boy"
|
||||
assert isinstance(args_ast[0], ast.Name)
|
||||
assert args_ast[0].id == "x"
|
||||
assert isinstance(args_ast[1], ast.Constant)
|
||||
assert args_ast[1].value == 1
|
||||
keyword_value_3 = keywords_ast[0].value
|
||||
assert isinstance(keyword_value_3, ast.Name)
|
||||
assert keyword_value_3.id == "Naroto"
|
||||
|
||||
node.set_arg(ScopedValue.create_variable_value(2), 1)
|
||||
assert isinstance(node.get_normalized_args().get("input2"), ScopedValue)
|
||||
assert node.get_normalized_args().get("input2").value == 2
|
||||
ast_node: ast.Assign = node.get_ast()
|
||||
assign_value: ast.Call = ast_node.value
|
||||
args_ast = assign_value.args
|
||||
assert args_ast[1].value == 2
|
||||
|
||||
args = node.get_args()
|
||||
assert args == [ScopedValue.create_naming_value('x'), ScopedValue.create_variable_value(2)]
|
||||
kwargs = node.get_kwargs()
|
||||
assert kwargs == {"cool_boy": ScopedValue.create_naming_value('Naroto')}
|
||||
|
||||
|
||||
def test_create_by_cell2():
|
||||
"""
|
||||
Feature: Python api create_call_cell of Node of Rewrite.
|
||||
Description: Call create_call_cell to create a node.
|
||||
Expectation: Success.
|
||||
"""
|
||||
node = Node.create_call_cell(FakeCell2(), None, ['x'], 'new_conv',
|
||||
[ScopedValue.create_naming_value('x'), ScopedValue.create_naming_value("x"),
|
||||
ScopedValue.create_naming_value('x'), ScopedValue.create_naming_value("x"),
|
||||
ScopedValue.create_naming_value('x'), ScopedValue.create_naming_value("x")],
|
||||
{"cool_boy": ScopedValue.create_naming_value('Naroto')}, 'new_conv')
|
||||
assert node.get_normalized_args() == {
|
||||
"a": ScopedValue.create_naming_value('x'),
|
||||
"b": ScopedValue.create_naming_value('x'),
|
||||
"d": ScopedValue.create_naming_value('x'),
|
||||
"e": ScopedValue.create_naming_value('x'),
|
||||
"args_4": ScopedValue.create_naming_value('x'),
|
||||
"args_5": ScopedValue.create_naming_value('x'),
|
||||
"cool_boy": ScopedValue.create_naming_value('Naroto'),
|
||||
}
|
||||
|
||||
|
||||
def test_create_by_cell3():
|
||||
"""
|
||||
Feature: Python api create_call_cell of Node of Rewrite.
|
||||
Description: Call create_call_cell to create a node.
|
||||
Expectation: Success.
|
||||
"""
|
||||
node = Node.create_call_cell(FakeCell3(), None, ['x'], 'new_conv',
|
||||
[ScopedValue.create_naming_value('x'), ScopedValue.create_naming_value("x"),
|
||||
ScopedValue.create_naming_value('x'), ScopedValue.create_naming_value("x")],
|
||||
{"h": ScopedValue.create_naming_value(1), "f": ScopedValue.create_naming_value(2),
|
||||
"cool_boy": ScopedValue.create_naming_value('Naroto')}, 'new_conv')
|
||||
assert node.get_normalized_args() == {
|
||||
"a": ScopedValue.create_naming_value('x'),
|
||||
"b": ScopedValue.create_naming_value('x'),
|
||||
"args_2": ScopedValue.create_naming_value('x'),
|
||||
"args_3": ScopedValue.create_naming_value('x'),
|
||||
"f": ScopedValue.create_naming_value(2),
|
||||
"h": ScopedValue.create_naming_value(1),
|
||||
"cool_boy": ScopedValue.create_naming_value('Naroto'),
|
||||
}
|
|
@ -0,0 +1,665 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import ast
|
||||
from collections import OrderedDict
|
||||
|
||||
from mindspore.nn import Cell, Conv2d, BatchNorm2d, ReLU
|
||||
from mindspore.ops import Add, AddN
|
||||
from mindspore.rewrite import ScopedValue, Node, SymbolTree
|
||||
from mindspore.rewrite import PatternEngine, PatternNode, Replacement, VarNode
|
||||
|
||||
|
||||
def test_tree_pattern_match():
|
||||
"""
|
||||
Feature: Python api PatternEngine.
|
||||
Description: Construct a tree PatternEngine and apply it on a SymbolTree, check SymbolTree after PatternEngine
|
||||
applied.
|
||||
Expectation: Success.
|
||||
"""
|
||||
assert True
|
||||
|
||||
|
||||
def test_leak_pattern_match():
|
||||
"""
|
||||
Feature: Python api PatternEngine.
|
||||
Description: Construct a leaked tree PatternEngine and apply it on a SymbolTree, check SymbolTree after
|
||||
PatternEngine applied.
|
||||
Expectation: Failure.
|
||||
"""
|
||||
assert True
|
||||
|
||||
|
||||
class ChainNetwork(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = Conv2d(16, 16, 3)
|
||||
self.bn = BatchNorm2d(16)
|
||||
self.relu1 = ReLU()
|
||||
self.relu2 = ReLU()
|
||||
self.relu3 = ReLU()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
x = self.relu1(x)
|
||||
x = self.relu2(x)
|
||||
x = self.relu3(x)
|
||||
return x
|
||||
|
||||
|
||||
def test_one_to_one_pattern():
|
||||
"""
|
||||
Feature: Python api PatternEngine.
|
||||
Description: Construct a one-to-one PatternEngine and apply it on a SymbolTree, check SymbolTree after PatternEngine
|
||||
applied.
|
||||
Expectation: Success.
|
||||
"""
|
||||
|
||||
class BnReplacement(Replacement):
|
||||
def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
|
||||
assert is_chain_pattern
|
||||
assert pattern.type() == BatchNorm2d
|
||||
bn_node: Node = matched.get(pattern.name())
|
||||
assert bn_node is not None
|
||||
|
||||
conv = Conv2d(16, 16, 3)
|
||||
conv_node = Node.create_call_cell(conv, ['x1'], bn_node.get_args(), bn_node.get_kwargs())
|
||||
return [conv_node]
|
||||
|
||||
class BnReplace(PatternEngine):
|
||||
def __init__(self):
|
||||
super().__init__([BatchNorm2d], BnReplacement())
|
||||
|
||||
net = ChainNetwork()
|
||||
stree = SymbolTree(net)
|
||||
conv = stree.get_node("conv")
|
||||
bn = stree.get_node("bn")
|
||||
relu1 = stree.get_node("relu1")
|
||||
construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast")
|
||||
assert conv is not None
|
||||
assert bn is not None
|
||||
assert relu1 is not None
|
||||
assert len(construct_ast.body) == 6
|
||||
assert len(stree.nodes()) == 7
|
||||
|
||||
bn_replace = BnReplace()
|
||||
bn_replace.apply(stree)
|
||||
|
||||
assert len(construct_ast.body) == 6
|
||||
assert len(stree.nodes()) == 7
|
||||
conv = stree.get_node("conv")
|
||||
bn = stree.get_node("bn")
|
||||
relu1 = stree.get_node("relu1")
|
||||
new_conv = stree.get_node("x1")
|
||||
assert conv is not None
|
||||
assert bn is None
|
||||
assert relu1 is not None
|
||||
assert new_conv is not None
|
||||
|
||||
# check conv topological order
|
||||
assert len(conv.get_users()) == 1
|
||||
assert conv.get_users()[0] == new_conv
|
||||
# check new_conv topological order
|
||||
assert len(new_conv.get_inputs()) == 1
|
||||
assert new_conv.get_inputs()[0] == conv
|
||||
assert len(new_conv.get_users()) == 1
|
||||
assert new_conv.get_users()[0] == relu1
|
||||
# check source code order
|
||||
assert getattr(conv.get_handler(), "_next") == new_conv.get_handler()
|
||||
assert getattr(new_conv.get_handler(), "_next") == relu1.get_handler()
|
||||
assert getattr(relu1.get_handler(), "_prev") == new_conv.get_handler()
|
||||
assert getattr(new_conv.get_handler(), "_prev") == conv.get_handler()
|
||||
# # check arg edge
|
||||
assert len(conv.get_targets()) == 1
|
||||
assert len(new_conv.get_args()) == 1
|
||||
assert conv.get_targets()[0] == new_conv.get_args()[0]
|
||||
assert len(new_conv.get_targets()) == 1
|
||||
assert len(relu1.get_args()) == 1
|
||||
assert new_conv.get_targets()[0] == relu1.get_args()[0]
|
||||
|
||||
|
||||
def test_one_to_multi_chain_pattern():
|
||||
"""
|
||||
Feature: Python api PatternEngine.
|
||||
Description: Construct a one-to-multi PatternEngine and apply it on a SymbolTree, check SymbolTree after
|
||||
PatternEngine applied.
|
||||
Expectation: Success.
|
||||
"""
|
||||
|
||||
class BnReplacement(Replacement):
|
||||
def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
|
||||
assert is_chain_pattern
|
||||
assert pattern.type() == BatchNorm2d
|
||||
bn_node: Node = matched.get(pattern.name())
|
||||
assert bn_node is not None
|
||||
|
||||
# Replacement should ensure target is unique in result
|
||||
# Replacement should ensure args and kwargs are well set by topological relation
|
||||
conv1 = Conv2d(16, 16, 3)
|
||||
conv_node1 = Node.create_call_cell(conv1, ['x1'], bn_node.get_args(), bn_node.get_kwargs())
|
||||
conv2 = Conv2d(16, 16, 5)
|
||||
conv_node2 = Node.create_call_cell(conv2, ['x2'], [ScopedValue.create_naming_value('x1')])
|
||||
return [conv_node1, conv_node2]
|
||||
|
||||
class BnReplace(PatternEngine):
|
||||
def __init__(self):
|
||||
super().__init__([BatchNorm2d], BnReplacement())
|
||||
|
||||
net = ChainNetwork()
|
||||
stree = SymbolTree(net)
|
||||
conv = stree.get_node("conv")
|
||||
bn = stree.get_node("bn")
|
||||
relu1 = stree.get_node("relu1")
|
||||
construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast")
|
||||
assert conv is not None
|
||||
assert bn is not None
|
||||
assert relu1 is not None
|
||||
assert len(construct_ast.body) == 6
|
||||
assert len(stree.nodes()) == 7
|
||||
|
||||
bn_replace = BnReplace()
|
||||
bn_replace.apply(stree)
|
||||
|
||||
assert len(construct_ast.body) == 7
|
||||
assert len(stree.nodes()) == 8
|
||||
conv = stree.get_node("conv")
|
||||
bn = stree.get_node("bn")
|
||||
relu1 = stree.get_node("relu1")
|
||||
new_conv1 = stree.get_node("x1")
|
||||
new_conv2 = stree.get_node("x2")
|
||||
assert conv is not None
|
||||
assert bn is None
|
||||
assert relu1 is not None
|
||||
assert new_conv1 is not None
|
||||
assert new_conv2 is not None
|
||||
|
||||
# check conv topological order
|
||||
assert len(conv.get_users()) == 1
|
||||
assert conv.get_users()[0] == new_conv1
|
||||
# check new_conv1 topological order
|
||||
assert len(new_conv1.get_inputs()) == 1
|
||||
assert new_conv1.get_inputs()[0] == conv
|
||||
assert len(new_conv1.get_users()) == 1
|
||||
assert new_conv1.get_users()[0] == new_conv2
|
||||
# check new_conv2 topological order
|
||||
assert len(new_conv2.get_inputs()) == 1
|
||||
assert new_conv2.get_inputs()[0] == new_conv1
|
||||
assert len(new_conv2.get_users()) == 1
|
||||
assert new_conv2.get_users()[0] == relu1
|
||||
# check source code order
|
||||
assert getattr(conv.get_handler(), "_next") == new_conv1.get_handler()
|
||||
assert getattr(new_conv1.get_handler(), "_next") == new_conv2.get_handler()
|
||||
assert getattr(new_conv2.get_handler(), "_next") == relu1.get_handler()
|
||||
assert getattr(relu1.get_handler(), "_prev") == new_conv2.get_handler()
|
||||
assert getattr(new_conv2.get_handler(), "_prev") == new_conv1.get_handler()
|
||||
assert getattr(new_conv1.get_handler(), "_prev") == conv.get_handler()
|
||||
# check arg edge
|
||||
assert len(conv.get_targets()) == 1
|
||||
assert len(new_conv1.get_args()) == 1
|
||||
assert conv.get_targets()[0] == new_conv1.get_args()[0]
|
||||
|
||||
assert len(new_conv1.get_targets()) == 1
|
||||
assert len(new_conv2.get_args()) == 1
|
||||
assert new_conv1.get_targets()[0] == new_conv2.get_args()[0]
|
||||
|
||||
assert len(new_conv2.get_targets()) == 1
|
||||
assert len(relu1.get_args()) == 1
|
||||
assert new_conv2.get_targets()[0] == relu1.get_args()[0]
|
||||
|
||||
|
||||
class TreeNetwork(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = Conv2d(16, 16, 3)
|
||||
self.conv2 = Conv2d(16, 16, 5)
|
||||
self.add = Add()
|
||||
self.relu = ReLU()
|
||||
self.relu1 = ReLU()
|
||||
self.relu2 = ReLU()
|
||||
|
||||
def construct(self, x):
|
||||
x1 = self.conv1(x)
|
||||
x2 = self.conv2(x)
|
||||
x = self.add(x1, x2)
|
||||
x = self.relu(x)
|
||||
x1 = self.relu1(x)
|
||||
x2 = self.relu2(x)
|
||||
x = self.add(x1, x2)
|
||||
return x
|
||||
|
||||
|
||||
def test_tree_pattern():
|
||||
"""
|
||||
Feature: Python api PatternEngine.
|
||||
Description: Construct a multi-to-multi PatternEngine and apply it on a SymbolTree, check SymbolTree after
|
||||
PatternEngine applied.
|
||||
Expectation: Success.
|
||||
"""
|
||||
|
||||
class AddReluReplacement(Replacement):
|
||||
def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
|
||||
assert is_chain_pattern
|
||||
assert pattern.type() == ReLU
|
||||
relu_node: Node = matched.get(pattern.name())
|
||||
assert relu_node is not None
|
||||
assert len(pattern.get_inputs()) == 1
|
||||
add_pattern = pattern.get_inputs()[0]
|
||||
assert add_pattern.type() == Add
|
||||
add_node: Node = matched.get(add_pattern.name())
|
||||
assert add_node is not None
|
||||
assert not add_pattern.get_inputs()
|
||||
|
||||
# can not use add_node here
|
||||
new_add1 = Add()
|
||||
new_add1_node = Node.create_call_cell(new_add1, ['new_add_1'], add_node.get_args(), add_node.get_kwargs())
|
||||
new_relu1 = ReLU()
|
||||
new_relu1_node = Node.create_call_cell(new_relu1, ['new_relu_1'],
|
||||
[ScopedValue.create_naming_value('new_add_1')])
|
||||
new_relu2 = ReLU()
|
||||
new_relu2_node = Node.create_call_cell(new_relu2, ['new_relu_2'],
|
||||
[ScopedValue.create_naming_value('new_add_1')])
|
||||
new_add2 = Add()
|
||||
new_add2_node = Node.create_call_cell(new_add2, ['new_add_2'],
|
||||
[ScopedValue.create_naming_value('new_relu_1'),
|
||||
ScopedValue.create_naming_value('new_relu_2')])
|
||||
return [new_add1_node, new_relu1_node, new_relu2_node, new_add2_node]
|
||||
|
||||
class AddReluPattern(PatternEngine):
|
||||
def __init__(self):
|
||||
super().__init__([Add, ReLU], AddReluReplacement())
|
||||
|
||||
net = TreeNetwork()
|
||||
stree = SymbolTree(net)
|
||||
conv1 = stree.get_node("conv1")
|
||||
conv2 = stree.get_node("conv2")
|
||||
add = stree.get_node("add")
|
||||
relu = stree.get_node("relu")
|
||||
relu1 = stree.get_node("relu1")
|
||||
relu2 = stree.get_node("relu2")
|
||||
assert conv1 is not None
|
||||
assert conv2 is not None
|
||||
assert add is not None
|
||||
assert relu is not None
|
||||
assert relu1 is not None
|
||||
assert relu2 is not None
|
||||
construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast")
|
||||
assert len(construct_ast.body) == 8
|
||||
assert len(stree.nodes()) == 9
|
||||
|
||||
add_relu_pattern = AddReluPattern()
|
||||
add_relu_pattern.apply(stree)
|
||||
|
||||
assert len(construct_ast.body) == 10
|
||||
assert len(stree.nodes()) == 11
|
||||
conv1 = stree.get_node("conv1")
|
||||
conv2 = stree.get_node("conv2")
|
||||
add = stree.get_node("add")
|
||||
relu = stree.get_node("relu")
|
||||
relu1 = stree.get_node("relu1")
|
||||
relu2 = stree.get_node("relu2")
|
||||
new_add = stree.get_node("new_add")
|
||||
new_relu = stree.get_node("new_relu")
|
||||
new_relu_1 = stree.get_node("new_relu_1")
|
||||
new_add_1 = stree.get_node("new_add_1")
|
||||
|
||||
assert conv1 is not None
|
||||
assert conv2 is not None
|
||||
assert add is None
|
||||
assert relu is None
|
||||
assert relu1 is not None
|
||||
assert relu2 is not None
|
||||
assert new_add is not None
|
||||
assert new_relu is not None
|
||||
assert new_relu_1 is not None
|
||||
assert new_add_1 is not None
|
||||
|
||||
# check conv1 topological order
|
||||
assert len(conv1.get_users()) == 1
|
||||
assert conv1.get_users()[0] == new_add
|
||||
# check conv2 topological order
|
||||
assert len(conv2.get_users()) == 1
|
||||
assert conv2.get_users()[0] == new_add
|
||||
# check new_add topological order
|
||||
assert len(new_add.get_inputs()) == 2
|
||||
assert new_add.get_inputs()[0] == conv1
|
||||
assert new_add.get_inputs()[1] == conv2
|
||||
assert len(new_add.get_users()) == 2
|
||||
assert new_add.get_users()[0] == new_relu
|
||||
assert new_add.get_users()[1] == new_relu_1
|
||||
# check new_relu topological order
|
||||
assert len(new_relu.get_inputs()) == 1
|
||||
assert new_relu.get_inputs()[0] == new_add
|
||||
assert len(new_relu.get_users()) == 1
|
||||
assert new_relu.get_users()[0] == new_add_1
|
||||
# check new_relu_1 topological order
|
||||
assert len(new_relu_1.get_inputs()) == 1
|
||||
assert new_relu_1.get_inputs()[0] == new_add
|
||||
assert len(new_relu_1.get_users()) == 1
|
||||
assert new_relu_1.get_users()[0] == new_add_1
|
||||
# check new_add_1 topological order
|
||||
assert len(new_add_1.get_inputs()) == 2
|
||||
assert new_add_1.get_inputs()[0] == new_relu_1
|
||||
assert new_add_1.get_inputs()[1] == new_relu
|
||||
assert len(new_add_1.get_users()) == 2
|
||||
assert new_add_1.get_users()[0] == relu1
|
||||
assert new_add_1.get_users()[1] == relu2
|
||||
# check source code order
|
||||
assert getattr(conv1.get_handler(), "_next") == conv2.get_handler()
|
||||
assert getattr(conv2.get_handler(), "_next") == new_add.get_handler()
|
||||
assert getattr(new_add.get_handler(), "_next") == new_relu.get_handler()
|
||||
assert getattr(new_relu.get_handler(), "_next") == new_relu_1.get_handler()
|
||||
assert getattr(new_relu_1.get_handler(), "_next") == new_add_1.get_handler()
|
||||
assert getattr(new_add_1.get_handler(), "_next") == relu1.get_handler()
|
||||
assert getattr(relu1.get_handler(), "_prev") == new_add_1.get_handler()
|
||||
assert getattr(new_add_1.get_handler(), "_prev") == new_relu_1.get_handler()
|
||||
assert getattr(new_relu_1.get_handler(), "_prev") == new_relu.get_handler()
|
||||
assert getattr(new_relu.get_handler(), "_prev") == new_add.get_handler()
|
||||
assert getattr(new_add.get_handler(), "_prev") == conv2.get_handler()
|
||||
assert getattr(conv2.get_handler(), "_prev") == conv1.get_handler()
|
||||
# check arg edge
|
||||
assert len(conv1.get_targets()) == 1
|
||||
assert len(conv2.get_targets()) == 1
|
||||
assert len(new_add.get_args()) == 2
|
||||
assert conv1.get_targets()[0] == new_add.get_args()[0]
|
||||
assert conv2.get_targets()[0] == new_add.get_args()[1]
|
||||
|
||||
assert len(new_add.get_targets()) == 1
|
||||
assert len(new_relu.get_args()) == 1
|
||||
assert len(new_relu_1.get_args()) == 1
|
||||
assert new_add.get_targets()[0] == new_relu.get_args()[0]
|
||||
assert new_add.get_targets()[0] == new_relu_1.get_args()[0]
|
||||
|
||||
assert len(new_relu.get_targets()) == 1
|
||||
assert len(new_relu_1.get_targets()) == 1
|
||||
assert len(new_add_1.get_args()) == 2
|
||||
assert new_relu.get_targets()[0] == new_add_1.get_args()[1]
|
||||
assert new_relu_1.get_targets()[0] == new_add_1.get_args()[0]
|
||||
|
||||
assert len(new_add_1.get_targets()) == 1
|
||||
assert len(relu1.get_args()) == 1
|
||||
assert len(relu2.get_args()) == 1
|
||||
assert new_add_1.get_targets()[0] == relu1.get_args()[0]
|
||||
assert new_add_1.get_targets()[0] == relu2.get_args()[0]
|
||||
|
||||
|
||||
class TreeNetwork2(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = Conv2d(16, 16, 1)
|
||||
self.conv2 = Conv2d(16, 16, 3)
|
||||
self.add1 = AddN()
|
||||
self.add2 = AddN()
|
||||
self.relu = ReLU()
|
||||
|
||||
def construct(self, x, y, z):
|
||||
x = self.conv1(x)
|
||||
y = self.conv2(y)
|
||||
z = self.add1(x, y, z)
|
||||
z = self.add2(x, y, z)
|
||||
z = self.relu(z)
|
||||
return z
|
||||
|
||||
|
||||
class MultiInputPattern(PatternEngine):
|
||||
class MultiInputReplacement(Replacement):
|
||||
def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
|
||||
assert not is_chain_pattern
|
||||
assert pattern.type() == AddN
|
||||
addn2_node: Node = matched.get(pattern.name())
|
||||
assert addn2_node is not None
|
||||
assert len(pattern.get_inputs()) == 3
|
||||
conv1_pn = pattern.get_inputs()[0]
|
||||
conv2_pn = pattern.get_inputs()[1]
|
||||
addn1_pn = pattern.get_inputs()[2]
|
||||
assert conv1_pn.type() == Conv2d
|
||||
assert conv2_pn.type() == Conv2d
|
||||
assert addn1_pn.type() == AddN
|
||||
conv1_node: Node = matched.get(conv1_pn.name())
|
||||
conv2_node: Node = matched.get(conv2_pn.name())
|
||||
addn1_node: Node = matched.get(addn1_pn.name())
|
||||
assert conv1_node is not None
|
||||
assert conv2_node is not None
|
||||
assert addn1_node is not None
|
||||
assert len(conv1_node.get_inputs()) == 1
|
||||
assert len(conv2_node.get_inputs()) == 1
|
||||
assert len(addn1_node.get_inputs()) == 3
|
||||
arg1 = conv1_node.get_args()[0]
|
||||
arg2 = conv2_node.get_args()[0]
|
||||
arg3 = addn1_node.get_args()[2]
|
||||
|
||||
# can not use add_node here
|
||||
new_add1 = Add()
|
||||
new_add1_node = Node.create_call_cell(new_add1, ['new_add1'], [arg1, arg2])
|
||||
new_add2 = Add()
|
||||
new_add2_node = Node.create_call_cell(new_add2, ['new_add2'], [ScopedValue.create_naming_value('new_add1'),
|
||||
arg3])
|
||||
return [new_add1_node, new_add2_node]
|
||||
|
||||
def __init__(self):
|
||||
conv1_pn = PatternNode("conv1", Conv2d)
|
||||
conv2_pn = PatternNode("conv2", Conv2d)
|
||||
addn1_pn = PatternNode("addn1", AddN)
|
||||
addn2_pn = PatternNode("addn2", AddN)
|
||||
conv1_pn.set_inputs([VarNode()])
|
||||
conv2_pn.set_inputs([VarNode()])
|
||||
addn1_pn.set_inputs([conv1_pn, conv2_pn, VarNode()])
|
||||
addn2_pn.set_inputs([conv1_pn, conv2_pn, addn1_pn])
|
||||
super().__init__(addn2_pn, MultiInputPattern.MultiInputReplacement())
|
||||
|
||||
|
||||
def test_multi_input_to_multi_pattern_tree_pattern():
|
||||
"""
|
||||
Feature: Python api PatternEngine.
|
||||
Description: Construct a multi-to-multi PatternEngine and apply it on a SymbolTree, check SymbolTree after
|
||||
PatternEngine applied.
|
||||
Expectation: Success.
|
||||
"""
|
||||
|
||||
net = TreeNetwork2()
|
||||
stree = SymbolTree(net)
|
||||
conv1 = stree.get_node("conv1")
|
||||
conv2 = stree.get_node("conv2")
|
||||
add1 = stree.get_node("add1")
|
||||
add2 = stree.get_node("add2")
|
||||
relu = stree.get_node("relu")
|
||||
assert conv1 is not None
|
||||
assert conv2 is not None
|
||||
assert add1 is not None
|
||||
assert add2 is not None
|
||||
assert relu is not None
|
||||
construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast")
|
||||
assert len(construct_ast.body) == 6
|
||||
assert len(stree.nodes()) == 9
|
||||
|
||||
multi_input_pattern = MultiInputPattern()
|
||||
multi_input_pattern.apply(stree)
|
||||
|
||||
assert len(construct_ast.body) == 4
|
||||
assert len(stree.nodes()) == 7
|
||||
conv1 = stree.get_node("conv1")
|
||||
conv2 = stree.get_node("conv2")
|
||||
add1 = stree.get_node("add1")
|
||||
add2 = stree.get_node("add2")
|
||||
relu = stree.get_node("relu")
|
||||
new_add1 = stree.get_node("new_add1")
|
||||
new_add2 = stree.get_node("new_add2")
|
||||
inputx = stree.get_node("input_x")
|
||||
inputy = stree.get_node("input_y")
|
||||
inputz = stree.get_node("input_z")
|
||||
assert conv1 is None
|
||||
assert conv2 is None
|
||||
assert add1 is None
|
||||
assert add2 is None
|
||||
assert relu is not None
|
||||
assert new_add1 is not None
|
||||
assert new_add2 is not None
|
||||
assert inputx is not None
|
||||
assert inputy is not None
|
||||
assert inputz is not None
|
||||
|
||||
# check inputx topological order
|
||||
assert len(inputx.get_users()) == 1
|
||||
assert inputx.get_users()[0] == new_add1
|
||||
# check inputy topological order
|
||||
assert len(inputy.get_users()) == 1
|
||||
assert inputy.get_users()[0] == new_add1
|
||||
# check inputz topological order
|
||||
assert len(inputz.get_users()) == 1
|
||||
assert inputz.get_users()[0] == new_add2
|
||||
# check new_add1 topological order
|
||||
assert len(new_add1.get_inputs()) == 2
|
||||
assert new_add1.get_inputs()[0] == inputx
|
||||
assert new_add1.get_inputs()[1] == inputy
|
||||
assert len(new_add1.get_users()) == 1
|
||||
assert new_add1.get_users()[0] == new_add2
|
||||
# check new_add2 topological order
|
||||
assert len(new_add2.get_inputs()) == 2
|
||||
assert new_add2.get_inputs()[0] == new_add1
|
||||
assert new_add2.get_inputs()[1] == inputz
|
||||
assert len(new_add2.get_users()) == 1
|
||||
assert new_add2.get_users()[0] == relu
|
||||
# check relu topological order
|
||||
assert len(relu.get_inputs()) == 1
|
||||
assert relu.get_inputs()[0] == new_add2
|
||||
# check source code order
|
||||
assert getattr(inputz.get_handler(), "_next") == new_add1.get_handler()
|
||||
assert getattr(new_add1.get_handler(), "_next") == new_add2.get_handler()
|
||||
assert getattr(new_add2.get_handler(), "_next") == relu.get_handler()
|
||||
assert getattr(relu.get_handler(), "_prev") == new_add2.get_handler()
|
||||
assert getattr(new_add2.get_handler(), "_prev") == new_add1.get_handler()
|
||||
assert getattr(new_add1.get_handler(), "_prev") == inputz.get_handler()
|
||||
# check arg edge
|
||||
assert len(inputx.get_targets()) == 1
|
||||
assert len(inputy.get_targets()) == 1
|
||||
assert len(new_add1.get_args()) == 2
|
||||
assert inputx.get_targets()[0] == new_add1.get_args()[0]
|
||||
assert inputy.get_targets()[0] == new_add1.get_args()[1]
|
||||
|
||||
assert len(inputz.get_targets()) == 1
|
||||
assert len(new_add1.get_targets()) == 1
|
||||
assert len(new_add2.get_args()) == 2
|
||||
assert new_add1.get_targets()[0] == new_add2.get_args()[0]
|
||||
assert inputz.get_targets()[0] == new_add2.get_args()[1]
|
||||
|
||||
assert len(new_add2.get_targets()) == 1
|
||||
assert len(relu.get_args()) == 1
|
||||
assert new_add2.get_targets()[0] == relu.get_args()[0]
|
||||
|
||||
|
||||
class TreeNetwork3(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = Conv2d(16, 16, 1)
|
||||
self.conv2 = Conv2d(16, 16, 3)
|
||||
self.add1 = AddN()
|
||||
self.add2 = AddN()
|
||||
self.relu = ReLU()
|
||||
|
||||
def construct(self, x):
|
||||
y = self.conv1(x)
|
||||
z = self.conv2(x)
|
||||
x = self.add1(y, z, x)
|
||||
x = self.add2(y, z, x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
def test_one_input_to_multi_pattern_tree_pattern():
|
||||
"""
|
||||
Feature: Python api PatternEngine.
|
||||
Description: Construct a multi-to-multi PatternEngine and apply it on a SymbolTree, check SymbolTree after
|
||||
PatternEngine applied.
|
||||
Expectation: Success.
|
||||
"""
|
||||
|
||||
net = TreeNetwork3()
|
||||
stree = SymbolTree(net)
|
||||
conv1 = stree.get_node("conv1")
|
||||
conv2 = stree.get_node("conv2")
|
||||
add1 = stree.get_node("add1")
|
||||
add2 = stree.get_node("add2")
|
||||
relu = stree.get_node("relu")
|
||||
assert conv1 is not None
|
||||
assert conv2 is not None
|
||||
assert add1 is not None
|
||||
assert add2 is not None
|
||||
assert relu is not None
|
||||
construct_ast: ast.FunctionDef = getattr(stree.get_handler(), "_root_ast")
|
||||
assert len(construct_ast.body) == 6
|
||||
assert len(stree.nodes()) == 7
|
||||
|
||||
multi_input_pattern = MultiInputPattern()
|
||||
multi_input_pattern.apply(stree)
|
||||
|
||||
assert len(construct_ast.body) == 4
|
||||
assert len(stree.nodes()) == 5
|
||||
conv1 = stree.get_node("conv1")
|
||||
conv2 = stree.get_node("conv2")
|
||||
add1 = stree.get_node("add1")
|
||||
add2 = stree.get_node("add2")
|
||||
relu = stree.get_node("relu")
|
||||
new_add1 = stree.get_node("new_add1")
|
||||
new_add2 = stree.get_node("new_add2")
|
||||
inputx = stree.get_node("input_x")
|
||||
assert conv1 is None
|
||||
assert conv2 is None
|
||||
assert add1 is None
|
||||
assert add2 is None
|
||||
assert relu is not None
|
||||
assert new_add1 is not None
|
||||
assert new_add2 is not None
|
||||
assert inputx is not None
|
||||
|
||||
# check inputx topological order
|
||||
assert len(inputx.get_users()) == 2
|
||||
assert inputx.get_users()[0] == new_add1
|
||||
assert inputx.get_users()[1] == new_add2
|
||||
# check new_add1 topological order
|
||||
assert len(new_add1.get_inputs()) == 2
|
||||
assert new_add1.get_inputs()[0] == inputx
|
||||
assert new_add1.get_inputs()[1] == inputx
|
||||
assert len(new_add1.get_users()) == 1
|
||||
assert new_add1.get_users()[0] == new_add2
|
||||
# check new_add2 topological order
|
||||
assert len(new_add2.get_inputs()) == 2
|
||||
assert new_add2.get_inputs()[0] == new_add1
|
||||
assert new_add2.get_inputs()[1] == inputx
|
||||
assert len(new_add2.get_users()) == 1
|
||||
assert new_add2.get_users()[0] == relu
|
||||
# check relu topological order
|
||||
assert len(relu.get_inputs()) == 1
|
||||
assert relu.get_inputs()[0] == new_add2
|
||||
# check source code order
|
||||
assert getattr(inputx.get_handler(), "_next") == new_add1.get_handler()
|
||||
assert getattr(new_add1.get_handler(), "_next") == new_add2.get_handler()
|
||||
assert getattr(new_add2.get_handler(), "_next") == relu.get_handler()
|
||||
assert getattr(relu.get_handler(), "_prev") == new_add2.get_handler()
|
||||
assert getattr(new_add2.get_handler(), "_prev") == new_add1.get_handler()
|
||||
assert getattr(new_add1.get_handler(), "_prev") == inputx.get_handler()
|
||||
# check arg edge
|
||||
assert len(inputx.get_targets()) == 1
|
||||
assert len(new_add1.get_args()) == 2
|
||||
assert inputx.get_targets()[0] == new_add1.get_args()[0]
|
||||
assert inputx.get_targets()[0] == new_add1.get_args()[1]
|
||||
|
||||
assert len(inputx.get_targets()) == 1
|
||||
assert len(new_add1.get_targets()) == 1
|
||||
assert len(new_add2.get_args()) == 2
|
||||
assert new_add1.get_targets()[0] == new_add2.get_args()[0]
|
||||
assert inputx.get_targets()[0] == new_add2.get_args()[1]
|
||||
|
||||
assert len(new_add2.get_targets()) == 1
|
||||
assert len(relu.get_args()) == 1
|
||||
assert new_add2.get_targets()[0] == relu.get_args()[0]
|
|
@ -0,0 +1,388 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import ast
|
||||
import inspect
|
||||
|
||||
from mindspore.nn import Cell, Conv2d, BatchNorm2d, ReLU
|
||||
from mindspore.ops import Add
|
||||
from mindspore.rewrite import ScopedValue, ValueType, NodeType
|
||||
from mindspore.rewrite import Node as NodeApi
|
||||
from mindspore.rewrite.symbol_tree import SymbolTree
|
||||
from mindspore.rewrite.node import Node
|
||||
|
||||
|
||||
class Network(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = Conv2d(16, 16, 3)
|
||||
self.bn = BatchNorm2d(16)
|
||||
self.relu1 = ReLU()
|
||||
self.relu2 = ReLU()
|
||||
self.relu3 = ReLU()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
x = self.relu1(x)
|
||||
x = self.relu2(x)
|
||||
x = self.relu3(x)
|
||||
return x
|
||||
|
||||
|
||||
def create_symbol_tree():
|
||||
net = Network()
|
||||
source = inspect.getsource(type(net))
|
||||
ast_root = ast.parse(source)
|
||||
ast_module = ast_root
|
||||
assert isinstance(ast_root, ast.Module)
|
||||
ast_class = ast_module.body[0]
|
||||
assert isinstance(ast_class, ast.ClassDef)
|
||||
ast_init_func = ast_class.body[0]
|
||||
assert isinstance(ast_init_func, ast.FunctionDef)
|
||||
ast_construct_func = ast_class.body[1]
|
||||
assert isinstance(ast_construct_func, ast.FunctionDef)
|
||||
ast_conv = ast_construct_func.body[0]
|
||||
ast_bn = ast_construct_func.body[1]
|
||||
ast_relu1 = ast_construct_func.body[2]
|
||||
ast_relu2 = ast_construct_func.body[3]
|
||||
ast_relu3 = ast_construct_func.body[4]
|
||||
ast_return = ast_construct_func.body[5]
|
||||
stree = SymbolTree(net, ast_module)
|
||||
stree.set_class_ast(ast_class)
|
||||
stree.set_init_func_ast(ast_init_func)
|
||||
stree.set_ast_root(ast_construct_func)
|
||||
stree.append_input_node("x")
|
||||
conv_node = Node.create_call_cell(net.conv, ast_conv, [ScopedValue.create_naming_value("x")],
|
||||
ScopedValue.create_naming_value("conv", "self"),
|
||||
[ScopedValue.create_naming_value("x")],
|
||||
{},
|
||||
"conv")
|
||||
stree.append_origin_field(conv_node)
|
||||
bn_node = Node.create_call_cell(net.bn, ast_bn, [ScopedValue.create_naming_value("x")],
|
||||
ScopedValue.create_naming_value("bn", "self"),
|
||||
[ScopedValue.create_naming_value("x")], {},
|
||||
"bn")
|
||||
bn_node = stree.append_origin_field(bn_node)
|
||||
relu1_node = Node.create_call_cell(net.relu1, ast_relu1, [ScopedValue.create_naming_value("x")],
|
||||
ScopedValue.create_naming_value("relu1", "self"),
|
||||
[ScopedValue.create_naming_value("x")],
|
||||
{}, "relu1")
|
||||
relu1_node = stree.append_origin_field(relu1_node)
|
||||
relu2_node = Node.create_call_cell(net.relu2, ast_relu2, [ScopedValue.create_naming_value("x")],
|
||||
ScopedValue.create_naming_value("relu2", "self"),
|
||||
[ScopedValue.create_naming_value("x")],
|
||||
{}, "relu2")
|
||||
relu2_node = stree.append_origin_field(relu2_node)
|
||||
relu3_node = Node.create_call_cell(net.relu3, ast_relu3, [ScopedValue.create_naming_value("x")],
|
||||
ScopedValue.create_naming_value("relu3", "self"),
|
||||
[ScopedValue.create_naming_value("x")],
|
||||
{}, "relu3")
|
||||
stree.append_origin_field(relu3_node)
|
||||
node_return = Node.create_output_node(ast_return, ["x"])
|
||||
stree.append_origin_field(node_return)
|
||||
return stree, bn_node, relu1_node, relu2_node
|
||||
|
||||
|
||||
def test_insert_node():
|
||||
"""
|
||||
Feature: Python api insert_node of SymbolTree of Rewrite.
|
||||
Description: Call insert_node to insert a node into SymbolTree.
|
||||
Expectation: Success.
|
||||
"""
|
||||
stree, _, relu1, relu2 = create_symbol_tree()
|
||||
construct_ast: ast.FunctionDef = getattr(stree, "_root_ast")
|
||||
providers = getattr(getattr(stree, "_topo_mgr"), "_target_provider")
|
||||
consumers = getattr(getattr(stree, "_topo_mgr"), "_target_consumer")
|
||||
providers_len = len(providers)
|
||||
consumers_len = len(consumers)
|
||||
assert len(stree.nodes()) == 7
|
||||
assert len(construct_ast.body) == 6
|
||||
assert len(relu1.get_targets()) == 1
|
||||
assert len(relu2.get_normalized_args().values()) == 1
|
||||
assert relu1.get_targets()[0] == list(relu2.get_normalized_args().values())[0]
|
||||
input1 = 1
|
||||
node = Node.create_call_cell(Add(), None, ['x'], 'new_conv',
|
||||
[ScopedValue.create_naming_value('x'), ScopedValue.create_variable_value(input1)], {},
|
||||
'new_conv')
|
||||
position = stree.before(relu2)
|
||||
node = stree.insert_node(position, node)
|
||||
# check nodes size
|
||||
assert len(stree.nodes()) == 8
|
||||
# check args
|
||||
assert len(relu2.get_normalized_args().values()) == 1
|
||||
assert relu1.get_targets()[0] == list(relu2.get_normalized_args().values())[0]
|
||||
assert len(node.get_normalized_args().values()) == 2
|
||||
assert list(node.get_normalized_args().values())[0] == ScopedValue.create_naming_value('x')
|
||||
assert list(node.get_normalized_args().values())[1].type == ValueType.IntValue
|
||||
# check provider
|
||||
assert len(providers) == providers_len + 1
|
||||
assert len(node.get_targets()) == 1
|
||||
assert providers.get(node.get_targets()[0])[0] == node
|
||||
assert providers.get(node.get_targets()[0])[1] == 0
|
||||
# check consumer
|
||||
assert len(consumers) == consumers_len + 1
|
||||
assert consumers.get(list(node.get_normalized_args().values())[1]) is not None
|
||||
# check inputs
|
||||
assert len(relu2.get_inputs()) == 1
|
||||
assert relu2.get_inputs()[0] == relu1
|
||||
assert len(node.get_inputs()) == 1
|
||||
assert node.get_inputs()[0].get_node_type() == NodeType.Input
|
||||
# check ast
|
||||
node_ast = node.get_ast()
|
||||
assert isinstance(node_ast, ast.Assign)
|
||||
args = node_ast.value.args
|
||||
assert isinstance(args, list)
|
||||
assert len(args) == 2
|
||||
assert isinstance(args[0], ast.Name)
|
||||
assert isinstance(args[1], ast.Constant)
|
||||
assert len(construct_ast.body) == 7
|
||||
|
||||
|
||||
def test_set_node_arg():
|
||||
"""
|
||||
Feature: Python api set_node_arg of SymbolTree of Rewrite.
|
||||
Description: Call set_node_arg to change topological-order of a node.
|
||||
Expectation: Success.
|
||||
"""
|
||||
stree, bn, relu1, relu2 = create_symbol_tree()
|
||||
assert len(stree.nodes()) == 7
|
||||
assert len(bn.get_targets()) == 1
|
||||
bn_output = bn.get_targets()[0]
|
||||
# check bn topological order
|
||||
assert len(stree.get_node_users(bn)) == 1
|
||||
assert stree.get_node_users(bn)[0][0] == relu1
|
||||
# check relu1 topological order
|
||||
assert len(stree.get_node_inputs(relu1)) == 1
|
||||
assert stree.get_node_inputs(relu1)[0] == bn
|
||||
assert len(stree.get_node_users(relu1)) == 1
|
||||
assert stree.get_node_users(relu1)[0][0] == relu2
|
||||
# check relu2 topological order
|
||||
assert len(stree.get_node_inputs(relu2)) == 1
|
||||
assert stree.get_node_inputs(relu2)[0] == relu1
|
||||
# check relu1 and relu2 edge
|
||||
assert len(relu1.get_targets()) == 1
|
||||
assert len(relu2.get_normalized_args().values()) == 1
|
||||
assert relu1.get_targets()[0] == list(relu2.get_normalized_args().values())[0]
|
||||
|
||||
stree.set_node_arg(relu2, 0, bn_output)
|
||||
# check bn topological order
|
||||
assert len(stree.get_node_users(bn)) == 2
|
||||
assert stree.get_node_users(bn)[0][0] == relu1
|
||||
assert stree.get_node_users(bn)[1][0] == relu2
|
||||
# check relu1 topological order
|
||||
assert len(stree.get_node_inputs(relu1)) == 1
|
||||
assert stree.get_node_inputs(relu1)[0] == bn
|
||||
assert len(stree.get_node_users(relu1)) == 0
|
||||
# check relu2 topological order
|
||||
assert len(stree.get_node_inputs(relu2)) == 1
|
||||
assert stree.get_node_inputs(relu2)[0] == bn
|
||||
# check bn and relu2 edge
|
||||
assert len(relu1.get_targets()) == 1
|
||||
assert len(relu2.get_normalized_args().values()) == 1
|
||||
assert bn_output == list(relu2.get_normalized_args().values())[0]
|
||||
# check ast
|
||||
node_ast = relu2.get_ast()
|
||||
assert isinstance(node_ast, ast.Assign)
|
||||
args = node_ast.value.args
|
||||
assert isinstance(args, list)
|
||||
assert len(args) == 1
|
||||
assert isinstance(args[0], ast.Name)
|
||||
assert args[0].id == bn_output.value
|
||||
|
||||
|
||||
def test_set_node_arg_by_node():
|
||||
"""
|
||||
Feature: Python api set_node_arg_by_node of SymbolTree of Rewrite.
|
||||
Description: Call set_node_arg_by_node to change topological-order of a node.
|
||||
Expectation: Success.
|
||||
"""
|
||||
stree, bn, relu1, relu2 = create_symbol_tree()
|
||||
assert len(stree.nodes()) == 7
|
||||
assert len(bn.get_targets()) == 1
|
||||
bn_output = bn.get_targets()[0]
|
||||
# check bn topological order
|
||||
assert len(stree.get_node_users(bn)) == 1
|
||||
assert stree.get_node_users(bn)[0][0] == relu1
|
||||
# check relu1 topological order
|
||||
assert len(stree.get_node_inputs(relu1)) == 1
|
||||
assert stree.get_node_inputs(relu1)[0] == bn
|
||||
assert len(stree.get_node_users(relu1)) == 1
|
||||
assert stree.get_node_users(relu1)[0][0] == relu2
|
||||
# check relu2 topological order
|
||||
assert len(stree.get_node_inputs(relu2)) == 1
|
||||
assert stree.get_node_inputs(relu2)[0] == relu1
|
||||
# check relu1 and relu2 edge
|
||||
assert len(relu1.get_targets()) == 1
|
||||
assert len(relu2.get_normalized_args().values()) == 1
|
||||
assert relu1.get_targets()[0] == list(relu2.get_normalized_args().values())[0]
|
||||
|
||||
stree.set_node_arg_by_node(relu2, 0, bn)
|
||||
# check bn topological order
|
||||
assert len(stree.get_node_users(bn)) == 2
|
||||
assert stree.get_node_users(bn)[0][0] == relu1
|
||||
assert stree.get_node_users(bn)[1][0] == relu2
|
||||
# check relu1 topological order
|
||||
assert len(stree.get_node_inputs(relu1)) == 1
|
||||
assert stree.get_node_inputs(relu1)[0] == bn
|
||||
assert len(stree.get_node_users(relu1)) == 0
|
||||
# check relu2 topological order
|
||||
assert len(stree.get_node_inputs(relu2)) == 1
|
||||
assert stree.get_node_inputs(relu2)[0] == bn
|
||||
# check bn and relu2 edge
|
||||
assert len(relu1.get_targets()) == 1
|
||||
assert len(relu2.get_normalized_args().values()) == 1
|
||||
assert bn_output == list(relu2.get_normalized_args().values())[0]
|
||||
# check ast
|
||||
node_ast = relu2.get_ast()
|
||||
assert isinstance(node_ast, ast.Assign)
|
||||
args = node_ast.value.args
|
||||
assert isinstance(args, list)
|
||||
assert len(args) == 1
|
||||
assert isinstance(args[0], ast.Name)
|
||||
assert args[0].id == bn_output.value
|
||||
|
||||
|
||||
def test_erase_succeed():
|
||||
"""
|
||||
Feature: Python api erase_node of SymbolTree of Rewrite.
|
||||
Description: Call erase_node to erase a node from SymbolTree.
|
||||
Expectation: Success.
|
||||
"""
|
||||
stree, bn, relu1, relu2 = create_symbol_tree()
|
||||
construct_ast: ast.FunctionDef = getattr(stree, "_root_ast")
|
||||
providers = getattr(getattr(stree, "_topo_mgr"), "_target_provider")
|
||||
providers_len = len(providers)
|
||||
assert len(stree.nodes()) == 7
|
||||
assert len(construct_ast.body) == 6
|
||||
|
||||
stree.set_node_arg_by_node(relu2, 0, bn)
|
||||
stree.erase_node(relu1)
|
||||
|
||||
assert len(stree.nodes()) == 6
|
||||
assert len(providers) == providers_len - 1
|
||||
assert len(construct_ast.body) == 5
|
||||
|
||||
|
||||
def test_erase_failed():
|
||||
"""
|
||||
Feature: Python api erase_node of SymbolTree of Rewrite.
|
||||
Description: Call erase_node to erase a node from SymbolTree which is not isolated.
|
||||
Expectation: Failure.
|
||||
"""
|
||||
stree, _, relu1, _ = create_symbol_tree()
|
||||
catched_error = False
|
||||
try:
|
||||
stree.erase_node(relu1)
|
||||
except RuntimeError:
|
||||
catched_error = True
|
||||
assert catched_error
|
||||
|
||||
|
||||
def test_replace_one_to_one():
|
||||
"""
|
||||
Feature: Python api replace of SymbolTree of Rewrite.
|
||||
Description: Call replace to replace an origin node to a new node.
|
||||
Expectation: Success.
|
||||
"""
|
||||
stree, bn, relu1, relu2 = create_symbol_tree()
|
||||
construct_ast: ast.FunctionDef = getattr(stree, "_root_ast")
|
||||
assert len(construct_ast.body) == 6
|
||||
assert len(stree.nodes()) == 7
|
||||
|
||||
new_conv = Conv2d(16, 16, 5)
|
||||
new_conv_node = NodeApi.create_call_cell(new_conv, [ScopedValue.create_naming_value("new_conv")],
|
||||
bn.get_targets()).get_handler()
|
||||
new_conv_node = stree.replace(relu1, [new_conv_node])
|
||||
assert len(stree.nodes()) == 7
|
||||
# check ast
|
||||
assert len(construct_ast.body) == 6
|
||||
node_ast: ast.Assign = construct_ast.body[2]
|
||||
func_ast: ast.Attribute = node_ast.value.func
|
||||
assert func_ast.attr == new_conv_node.get_name()
|
||||
# check bn topological order
|
||||
assert len(stree.get_node_users(bn)) == 1
|
||||
assert stree.get_node_users(bn)[0][0] == new_conv_node
|
||||
# check new_conv_node topological order
|
||||
assert len(stree.get_node_inputs(new_conv_node)) == 1
|
||||
assert stree.get_node_inputs(new_conv_node)[0] == bn
|
||||
assert len(stree.get_node_users(new_conv_node)) == 1
|
||||
assert stree.get_node_users(new_conv_node)[0][0] == relu2
|
||||
# check relu2 topological order
|
||||
assert len(stree.get_node_inputs(relu2)) == 1
|
||||
assert stree.get_node_inputs(relu2)[0] == new_conv_node
|
||||
# check arg edge
|
||||
assert len(bn.get_targets()) == 1
|
||||
assert len(new_conv_node.get_normalized_args().values()) == 1
|
||||
assert bn.get_targets()[0] == list(new_conv_node.get_normalized_args().values())[0]
|
||||
assert len(new_conv_node.get_targets()) == 1
|
||||
assert len(relu2.get_normalized_args().values()) == 1
|
||||
assert new_conv_node.get_targets()[0] == list(relu2.get_normalized_args().values())[0]
|
||||
|
||||
|
||||
def test_replace_one_to_multi():
|
||||
"""
|
||||
Feature: Python api replace of SymbolTree of Rewrite.
|
||||
Description: Call replace to replace an origin node to a new node-tree.
|
||||
Expectation: Success.
|
||||
"""
|
||||
stree, bn, relu1, relu2 = create_symbol_tree()
|
||||
construct_ast: ast.FunctionDef = getattr(stree, "_root_ast")
|
||||
assert len(construct_ast.body) == 6
|
||||
assert len(stree.nodes()) == 7
|
||||
|
||||
new_conv_node = NodeApi.create_call_cell(Conv2d(16, 16, 5), [ScopedValue.create_naming_value("new_conv")],
|
||||
bn.get_targets()).get_handler()
|
||||
new_relu_node = NodeApi.create_call_cell(ReLU(), [ScopedValue.create_naming_value("new_relu")],
|
||||
new_conv_node.get_targets()).get_handler()
|
||||
new_relu_node = stree.replace(relu1, [new_relu_node, new_conv_node])
|
||||
new_conv_node = new_relu_node.get_inputs()[0]
|
||||
|
||||
assert len(stree.nodes()) == 8
|
||||
# check ast
|
||||
assert len(construct_ast.body) == 7
|
||||
new_conv_ast: ast.Assign = construct_ast.body[2]
|
||||
new_conv_func_ast: ast.Attribute = new_conv_ast.value.func
|
||||
assert new_conv_func_ast.attr == new_conv_node.get_name()
|
||||
new_relu_ast: ast.Assign = construct_ast.body[3]
|
||||
new_relu_func_ast: ast.Attribute = new_relu_ast.value.func
|
||||
assert new_relu_func_ast.attr == new_relu_node.get_name()
|
||||
# check bn topological order
|
||||
assert len(stree.get_node_users(bn)) == 1
|
||||
assert stree.get_node_users(bn)[0][0] == new_conv_node
|
||||
# check new_conv_node topological order
|
||||
assert len(stree.get_node_inputs(new_conv_node)) == 1
|
||||
assert stree.get_node_inputs(new_conv_node)[0] == bn
|
||||
assert len(stree.get_node_users(new_conv_node)) == 1
|
||||
assert stree.get_node_users(new_conv_node)[0][0] == new_relu_node
|
||||
# check new_relu_node topological order
|
||||
assert len(stree.get_node_inputs(new_relu_node)) == 1
|
||||
assert stree.get_node_inputs(new_relu_node)[0] == new_conv_node
|
||||
assert len(stree.get_node_users(new_relu_node)) == 1
|
||||
assert stree.get_node_users(new_relu_node)[0][0] == relu2
|
||||
# check relu2 topological order
|
||||
assert len(stree.get_node_inputs(relu2)) == 1
|
||||
assert stree.get_node_inputs(relu2)[0] == new_relu_node
|
||||
# check arg edge
|
||||
assert len(bn.get_targets()) == 1
|
||||
assert len(new_conv_node.get_normalized_args().values()) == 1
|
||||
assert bn.get_targets()[0] == list(new_conv_node.get_normalized_args().values())[0]
|
||||
|
||||
assert len(new_conv_node.get_targets()) == 1
|
||||
assert len(new_relu_node.get_normalized_args().values()) == 1
|
||||
assert new_conv_node.get_targets()[0] == list(new_relu_node.get_normalized_args().values())[0]
|
||||
|
||||
assert len(new_relu_node.get_targets()) == 1
|
||||
assert len(relu2.get_normalized_args().values()) == 1
|
||||
assert new_relu_node.get_targets()[0] == list(relu2.get_normalized_args().values())[0]
|
|
@ -119,6 +119,12 @@ else
|
|||
if [ ${RET} -ne 0 ]; then
|
||||
exit ${RET}
|
||||
fi
|
||||
|
||||
pytest -s $CURRPATH/rewrite/*.py
|
||||
RET=$?
|
||||
if [ ${RET} -ne 0 ]; then
|
||||
exit ${RET}
|
||||
fi
|
||||
fi
|
||||
|
||||
RET=$?
|
||||
|
|
Loading…
Reference in New Issue