!49004 rewrite增加符号数节点MathOps,用于支持加减乘除等数学运算操作

Merge pull request !49004 from GuoZhibin/add_rewrite_node_mathops
This commit is contained in:
i-robot 2023-02-21 02:18:40 +00:00 committed by Gitee
commit 212213ea60
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 143 additions and 20 deletions

View File

@ -29,6 +29,7 @@ class NodeType(Enum):
- Input: `Input` node represents input of `SymbolTree` corresponding to arguments of forward method.
- Output: `Output` node represents output of SymbolTree corresponding to return statement of forward method.
- Tree: `Tree` node represents sub-network invoking in forward method.
- MathOps: `MathOps` node represents a mathematical operation, such as adding or comparing in forward method.
"""
Unknown = 0
@ -44,3 +45,4 @@ class NodeType(Enum):
Output = 8
Tree = 9
CellContainer = 10
MathOps = 11

View File

@ -461,7 +461,7 @@ class AstModifier(ast.NodeTransformer):
Args:
src_argument (ScopedValue): An instance of ScopedValue represents new argument.
dst_ast (ast.AST): Targets of ast.Assign.
dst_ast (ast.AST): Ast node to be updated by ScopedValue.
Raises:
TypeError: Input src_argument is not a ScopedValue
@ -492,6 +492,12 @@ class AstModifier(ast.NodeTransformer):
str(src_argument.type))
dst_ast.n = src_argument.value
return
if isinstance(dst_ast, ast.Str):
if src_argument.type not in [ValueType.StringValue]:
raise RuntimeError("src_argument should be a StringValue, but got:",
str(src_argument.type))
dst_ast.s = src_argument.value
return
if isinstance(dst_ast, ast.Name):
if src_argument.type not in [ValueType.NamingValue, ValueType.StringValue]:
raise RuntimeError("src_argument.type should be ValueType.NamingValue or ValueType.StringValue.")

View File

@ -38,7 +38,8 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
ast.Call: ["args"],
ast.BinOp: ["left", "right"],
ast.BoolOp: ["values"],
ast.unaryop: ["operand"],
ast.UnaryOp: ["operand"],
ast.Compare: ["left", "comparators"],
}
@staticmethod
@ -55,7 +56,7 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
target_name = "function"
elif isinstance(node, ast.Return):
target_name = "return_value"
elif isinstance(node, (ast.BinOp, ast.boolop, ast.UnaryOp)):
elif isinstance(node, (ast.BinOp, ast.BoolOp, ast.UnaryOp)):
target_name = type(node.op).__name__.lower() + "_var"
elif isinstance(node, ast.Tuple):
target_name = type(node).__name__.lower() + "_var"

View File

@ -222,6 +222,32 @@ class Node:
return cls(NodeType.Output, ast_node, None, ScopedValue.create_naming_value("return"), real_return_values, {},
name, None)
@classmethod
def create_mathops_node(cls, ast_node: ast.AST, targets: [ScopedValue],
op_type: ScopedValue, args: [ScopedValue],
ops: {str: list}, name: str = ""):
"""
Class method of Node. Instantiate an instance of node whose type is `MathOps` .
A mathops node is used to represent a node with mathematical operations, such as
`y = a + b` , `y = not a` , `y = 0 < a < 1`, `y = a or b` , etc.
Args:
ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast. The type of
node is ast.Assign, and the type of ast_node.value is one of ast.BinOp, ast.UnaryOp, ast.BoolOp and
ast.Compare.
targets (list[ScopedValue]): Targets of mathematical operations. A list of instance of `ScopedValue`.
See detail in docstring of Node class.
op_type (ScopedValue): The type of ast_node.value saved by string. A ScopedValue with NamingValue type.
args (list[ScopedValue]): Values participating in the mathematical operations. All values are saved
sequentially in the list.
ops (dict[str:ScopedValue]): Operators participating in the mathematical operations. All operators are
saved sequentially in the dict, and keys are numbers in string format, such as {'0':'add', '1':'sub'}.
name (str): A string represents name of node. Name of node will be unique when inserted into `SymbolTree`.
Name of node also used as field name in network class. The format of mathops node name
is 'AstNodeName_AstOpName_n'.
"""
return cls(NodeType.MathOps, ast_node, targets, op_type, args, ops, name, None)
@staticmethod
def create_call_op(op: Union[Cell, Primitive], ast_node: Optional[ast.AST], targets: [Union[ScopedValue, str]],
func: Union[ScopedValue, str], args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None,
@ -624,7 +650,8 @@ class Node:
"""
self._targets = targets
if self._node_type in (NodeType.CallCell, NodeType.CallMethod, NodeType.CallPrimitive,
NodeType.Tree, NodeType.CallFunction, NodeType.CellContainer):
NodeType.Tree, NodeType.CallFunction, NodeType.CellContainer,
NodeType.MathOps):
self._sync_assign_targets_to_ast()
def get_func(self) -> ScopedValue:
@ -1133,14 +1160,50 @@ class Node:
raise RuntimeError("Unsupported return value type: ", return_value_ast)
ast.fix_missing_locations(return_ast)
def _sync_mathops_node_args_to_ast(self):
"""
Sync values from self._normalized_args to the ast node for mathematical operations.
"""
if self._ast_node is None:
return
if not isinstance(self._ast_node, ast.Assign):
raise TypeError(f"type of node should be ast.Assign, but got {type(self._ast_node)}")
mathops_node = self._ast_node.value
if isinstance(mathops_node, ast.BinOp):
left = mathops_node.left
right = mathops_node.right
AstModifier.update_arg_value(self._normalized_args.get(self._normalized_args_keys[0]), left)
AstModifier.update_arg_value(self._normalized_args.get(self._normalized_args_keys[1]), right)
elif isinstance(mathops_node, ast.UnaryOp):
operand = mathops_node.operand
AstModifier.update_arg_value(self._normalized_args.get(self._normalized_args_keys[0]), operand)
elif isinstance(mathops_node, ast.BoolOp):
values = mathops_node.values
for arg_index in range(self._args_num):
arg_value = self._normalized_args.get(self._normalized_args_keys[arg_index])
AstModifier.update_arg_value(arg_value, values[arg_index])
elif isinstance(mathops_node, ast.Compare):
left = mathops_node.left
AstModifier.update_arg_value(self._normalized_args.get(self._normalized_args_keys[0]), left)
comparators = mathops_node.comparators
for arg_index in range(1, self._args_num):
arg_value = self._normalized_args.get(self._normalized_args_keys[arg_index])
AstModifier.update_arg_value(arg_value, comparators[arg_index - 1])
else:
raise TypeError("The type of 'mathops_node' must be one of (ast.BinOp, ast.UnaryOp, "
"ast.BoolOp, ast.Compare), but got ", type(mathops_node))
def _sync_arg(self):
"""Sync _normalized_args to corresponding ast node when updated."""
if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree, NodeType.CellContainer):
if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree,\
NodeType.CellContainer, NodeType.CallFunction):
self._sync_call_cell_args_to_ast()
elif self._node_type == NodeType.Output:
self._sync_return_node_to_ast()
elif self._node_type == NodeType.CallMethod:
self._sync_call_method_args_to_ast()
elif self._node_type == NodeType.MathOps:
self._sync_mathops_node_args_to_ast()
class TreeNode(Node):

View File

@ -466,6 +466,62 @@ class AssignParser(Parser):
return False
return True
@staticmethod
def _convert_ast_mathops_to_node(ast_node: Union[ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare],
father_ast_node: ast.Assign) -> Node:
"""
Convert ast node of math operations(ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare) to
a symbol tree node.
Args:
ast_node (Union[ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare]): An assign node with mathematival
operation in construct function.
father_ast_node (ast.Assign): Assign node in construct.
Returns:
An instance of Node in Symbol Tree.
Raises:
TypeError: The type of parameter 'ast_node' is not in (ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare).
"""
if not isinstance(ast_node, (ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare)):
raise TypeError("The type of parameter 'ast_node' must be one of (ast.BinOp, ast.UnaryOp, "
"ast.BoolOp, ast.Compare), but got ", type(ast_node))
targets = AssignParser._get_targets(AssignParser._create_scopedvalue(father_ast_node.targets[0]))
args = []
op_type_str = type(ast_node).__name__
op_type = ScopedValue.create_naming_value(op_type_str)
ops = {}
name = op_type_str
if isinstance(ast_node, ast.BinOp):
op = type(ast_node.op).__name__
name = name + '_' + op
ops['0'] = ScopedValue.create_naming_value(op)
args.append(AssignParser._create_scopedvalue(ast_node.left))
args.append(AssignParser._create_scopedvalue(ast_node.right))
elif isinstance(ast_node, ast.UnaryOp):
op = type(ast_node.op).__name__
name = name + '_' + op
ops['0'] = ScopedValue.create_naming_value(op)
args.append(AssignParser._create_scopedvalue(ast_node.operand))
elif isinstance(ast_node, ast.BoolOp):
op = type(ast_node.op).__name__
name = name + '_' + op
ops['0'] = ScopedValue.create_naming_value(op)
for value in ast_node.values:
args.append(AssignParser._create_scopedvalue(value))
elif isinstance(ast_node, ast.Compare):
args.append(AssignParser._create_scopedvalue(ast_node.left))
for idx, ast_op in enumerate(ast_node.ops):
op = type(ast_op).__name__
name = name + '_' + op
ops[str(idx)] = ScopedValue.create_naming_value(op)
args.append(AssignParser._create_scopedvalue(ast_node.comparators[idx]))
name = name.lower()
return Node.create_mathops_node(father_ast_node, targets, op_type, args, ops, name)
def process(self, stree: SymbolTree, node: ast.Assign):
"""
Parse ast.Assign and create a node in symbol tree.
@ -493,11 +549,10 @@ class AssignParser(Parser):
if isinstance(value, ast.Call):
node_ = self._convert_ast_call_to_node(value, node, stree)
stree.append_origin_field(node_)
elif isinstance(value, ast.BinOp):
logger.info(f"ops-call({astunparse.unparse(node)}) in assign will be supported in near feature, "
f"ignored as a python node now")
stree.try_append_python_node(node, node)
elif isinstance(value, (ast.BoolOp, ast.Subscript)):
elif isinstance(value, (ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare)):
node_ = AssignParser._convert_ast_mathops_to_node(value, node)
stree.append_origin_field(node_)
elif isinstance(value, ast.Subscript):
logger.info(f"ops-call({astunparse.unparse(node)}) in assign will be supported in near feature, "
f"ignored as a python node now")
stree.try_append_python_node(node, node)

View File

@ -1337,17 +1337,13 @@ class SymbolTree(Observer, Observable):
tmp_module_name = tmp_module_file[:-3]
sys.path.append(tmp_module_path)
tmp_module = None
i = 0
while not tmp_module:
try:
tmp_module = importlib.import_module(tmp_module_name)
except ModuleNotFoundError:
while i > 10:
break
time.sleep(0.1)
i += 1
try:
tmp_module = importlib.import_module(tmp_module_name)
except ModuleNotFoundError:
time.sleep(1)
if not tmp_module:
logger.error(f"load module {tmp_module_name} failed.")
tmp_module = importlib.import_module(tmp_module_name)
network_cls = getattr(tmp_module, self._opt_cls_name)
if network_cls is None:
raise RuntimeError("Can not find network class:", self._opt_cls_name)