forked from mindspore-Ecosystem/mindspore
!49004 rewrite增加符号数节点MathOps,用于支持加减乘除等数学运算操作
Merge pull request !49004 from GuoZhibin/add_rewrite_node_mathops
This commit is contained in:
commit
212213ea60
|
@ -29,6 +29,7 @@ class NodeType(Enum):
|
|||
- Input: `Input` node represents input of `SymbolTree` corresponding to arguments of forward method.
|
||||
- Output: `Output` node represents output of SymbolTree corresponding to return statement of forward method.
|
||||
- Tree: `Tree` node represents sub-network invoking in forward method.
|
||||
- MathOps: `MathOps` node represents a mathematical operation, such as adding or comparing in forward method.
|
||||
|
||||
"""
|
||||
Unknown = 0
|
||||
|
@ -44,3 +45,4 @@ class NodeType(Enum):
|
|||
Output = 8
|
||||
Tree = 9
|
||||
CellContainer = 10
|
||||
MathOps = 11
|
||||
|
|
|
@ -461,7 +461,7 @@ class AstModifier(ast.NodeTransformer):
|
|||
|
||||
Args:
|
||||
src_argument (ScopedValue): An instance of ScopedValue represents new argument.
|
||||
dst_ast (ast.AST): Targets of ast.Assign.
|
||||
dst_ast (ast.AST): Ast node to be updated by ScopedValue.
|
||||
|
||||
Raises:
|
||||
TypeError: Input src_argument is not a ScopedValue
|
||||
|
@ -492,6 +492,12 @@ class AstModifier(ast.NodeTransformer):
|
|||
str(src_argument.type))
|
||||
dst_ast.n = src_argument.value
|
||||
return
|
||||
if isinstance(dst_ast, ast.Str):
|
||||
if src_argument.type not in [ValueType.StringValue]:
|
||||
raise RuntimeError("src_argument should be a StringValue, but got:",
|
||||
str(src_argument.type))
|
||||
dst_ast.s = src_argument.value
|
||||
return
|
||||
if isinstance(dst_ast, ast.Name):
|
||||
if src_argument.type not in [ValueType.NamingValue, ValueType.StringValue]:
|
||||
raise RuntimeError("src_argument.type should be ValueType.NamingValue or ValueType.StringValue.")
|
||||
|
|
|
@ -38,7 +38,8 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
|||
ast.Call: ["args"],
|
||||
ast.BinOp: ["left", "right"],
|
||||
ast.BoolOp: ["values"],
|
||||
ast.unaryop: ["operand"],
|
||||
ast.UnaryOp: ["operand"],
|
||||
ast.Compare: ["left", "comparators"],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
|
@ -55,7 +56,7 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
|||
target_name = "function"
|
||||
elif isinstance(node, ast.Return):
|
||||
target_name = "return_value"
|
||||
elif isinstance(node, (ast.BinOp, ast.boolop, ast.UnaryOp)):
|
||||
elif isinstance(node, (ast.BinOp, ast.BoolOp, ast.UnaryOp)):
|
||||
target_name = type(node.op).__name__.lower() + "_var"
|
||||
elif isinstance(node, ast.Tuple):
|
||||
target_name = type(node).__name__.lower() + "_var"
|
||||
|
|
|
@ -222,6 +222,32 @@ class Node:
|
|||
return cls(NodeType.Output, ast_node, None, ScopedValue.create_naming_value("return"), real_return_values, {},
|
||||
name, None)
|
||||
|
||||
@classmethod
|
||||
def create_mathops_node(cls, ast_node: ast.AST, targets: [ScopedValue],
|
||||
op_type: ScopedValue, args: [ScopedValue],
|
||||
ops: {str: list}, name: str = ""):
|
||||
"""
|
||||
Class method of Node. Instantiate an instance of node whose type is `MathOps` .
|
||||
A mathops node is used to represent a node with mathematical operations, such as
|
||||
`y = a + b` , `y = not a` , `y = 0 < a < 1`, `y = a or b` , etc.
|
||||
|
||||
Args:
|
||||
ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast. The type of
|
||||
node is ast.Assign, and the type of ast_node.value is one of ast.BinOp, ast.UnaryOp, ast.BoolOp and
|
||||
ast.Compare.
|
||||
targets (list[ScopedValue]): Targets of mathematical operations. A list of instance of `ScopedValue`.
|
||||
See detail in docstring of Node class.
|
||||
op_type (ScopedValue): The type of ast_node.value saved by string. A ScopedValue with NamingValue type.
|
||||
args (list[ScopedValue]): Values participating in the mathematical operations. All values are saved
|
||||
sequentially in the list.
|
||||
ops (dict[str:ScopedValue]): Operators participating in the mathematical operations. All operators are
|
||||
saved sequentially in the dict, and keys are numbers in string format, such as {'0':'add', '1':'sub'}.
|
||||
name (str): A string represents name of node. Name of node will be unique when inserted into `SymbolTree`.
|
||||
Name of node also used as field name in network class. The format of mathops node name
|
||||
is 'AstNodeName_AstOpName_n'.
|
||||
"""
|
||||
return cls(NodeType.MathOps, ast_node, targets, op_type, args, ops, name, None)
|
||||
|
||||
@staticmethod
|
||||
def create_call_op(op: Union[Cell, Primitive], ast_node: Optional[ast.AST], targets: [Union[ScopedValue, str]],
|
||||
func: Union[ScopedValue, str], args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None,
|
||||
|
@ -624,7 +650,8 @@ class Node:
|
|||
"""
|
||||
self._targets = targets
|
||||
if self._node_type in (NodeType.CallCell, NodeType.CallMethod, NodeType.CallPrimitive,
|
||||
NodeType.Tree, NodeType.CallFunction, NodeType.CellContainer):
|
||||
NodeType.Tree, NodeType.CallFunction, NodeType.CellContainer,
|
||||
NodeType.MathOps):
|
||||
self._sync_assign_targets_to_ast()
|
||||
|
||||
def get_func(self) -> ScopedValue:
|
||||
|
@ -1133,14 +1160,50 @@ class Node:
|
|||
raise RuntimeError("Unsupported return value type: ", return_value_ast)
|
||||
ast.fix_missing_locations(return_ast)
|
||||
|
||||
def _sync_mathops_node_args_to_ast(self):
|
||||
"""
|
||||
Sync values from self._normalized_args to the ast node for mathematical operations.
|
||||
"""
|
||||
if self._ast_node is None:
|
||||
return
|
||||
if not isinstance(self._ast_node, ast.Assign):
|
||||
raise TypeError(f"type of node should be ast.Assign, but got {type(self._ast_node)}")
|
||||
mathops_node = self._ast_node.value
|
||||
if isinstance(mathops_node, ast.BinOp):
|
||||
left = mathops_node.left
|
||||
right = mathops_node.right
|
||||
AstModifier.update_arg_value(self._normalized_args.get(self._normalized_args_keys[0]), left)
|
||||
AstModifier.update_arg_value(self._normalized_args.get(self._normalized_args_keys[1]), right)
|
||||
elif isinstance(mathops_node, ast.UnaryOp):
|
||||
operand = mathops_node.operand
|
||||
AstModifier.update_arg_value(self._normalized_args.get(self._normalized_args_keys[0]), operand)
|
||||
elif isinstance(mathops_node, ast.BoolOp):
|
||||
values = mathops_node.values
|
||||
for arg_index in range(self._args_num):
|
||||
arg_value = self._normalized_args.get(self._normalized_args_keys[arg_index])
|
||||
AstModifier.update_arg_value(arg_value, values[arg_index])
|
||||
elif isinstance(mathops_node, ast.Compare):
|
||||
left = mathops_node.left
|
||||
AstModifier.update_arg_value(self._normalized_args.get(self._normalized_args_keys[0]), left)
|
||||
comparators = mathops_node.comparators
|
||||
for arg_index in range(1, self._args_num):
|
||||
arg_value = self._normalized_args.get(self._normalized_args_keys[arg_index])
|
||||
AstModifier.update_arg_value(arg_value, comparators[arg_index - 1])
|
||||
else:
|
||||
raise TypeError("The type of 'mathops_node' must be one of (ast.BinOp, ast.UnaryOp, "
|
||||
"ast.BoolOp, ast.Compare), but got ", type(mathops_node))
|
||||
|
||||
def _sync_arg(self):
|
||||
"""Sync _normalized_args to corresponding ast node when updated."""
|
||||
if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree, NodeType.CellContainer):
|
||||
if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree,\
|
||||
NodeType.CellContainer, NodeType.CallFunction):
|
||||
self._sync_call_cell_args_to_ast()
|
||||
elif self._node_type == NodeType.Output:
|
||||
self._sync_return_node_to_ast()
|
||||
elif self._node_type == NodeType.CallMethod:
|
||||
self._sync_call_method_args_to_ast()
|
||||
elif self._node_type == NodeType.MathOps:
|
||||
self._sync_mathops_node_args_to_ast()
|
||||
|
||||
|
||||
class TreeNode(Node):
|
||||
|
|
|
@ -466,6 +466,62 @@ class AssignParser(Parser):
|
|||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _convert_ast_mathops_to_node(ast_node: Union[ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare],
|
||||
father_ast_node: ast.Assign) -> Node:
|
||||
"""
|
||||
Convert ast node of math operations(ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare) to
|
||||
a symbol tree node.
|
||||
|
||||
Args:
|
||||
ast_node (Union[ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare]): An assign node with mathematival
|
||||
operation in construct function.
|
||||
father_ast_node (ast.Assign): Assign node in construct.
|
||||
|
||||
Returns:
|
||||
An instance of Node in Symbol Tree.
|
||||
|
||||
Raises:
|
||||
TypeError: The type of parameter 'ast_node' is not in (ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare).
|
||||
|
||||
"""
|
||||
if not isinstance(ast_node, (ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare)):
|
||||
raise TypeError("The type of parameter 'ast_node' must be one of (ast.BinOp, ast.UnaryOp, "
|
||||
"ast.BoolOp, ast.Compare), but got ", type(ast_node))
|
||||
|
||||
targets = AssignParser._get_targets(AssignParser._create_scopedvalue(father_ast_node.targets[0]))
|
||||
args = []
|
||||
op_type_str = type(ast_node).__name__
|
||||
op_type = ScopedValue.create_naming_value(op_type_str)
|
||||
ops = {}
|
||||
name = op_type_str
|
||||
if isinstance(ast_node, ast.BinOp):
|
||||
op = type(ast_node.op).__name__
|
||||
name = name + '_' + op
|
||||
ops['0'] = ScopedValue.create_naming_value(op)
|
||||
args.append(AssignParser._create_scopedvalue(ast_node.left))
|
||||
args.append(AssignParser._create_scopedvalue(ast_node.right))
|
||||
elif isinstance(ast_node, ast.UnaryOp):
|
||||
op = type(ast_node.op).__name__
|
||||
name = name + '_' + op
|
||||
ops['0'] = ScopedValue.create_naming_value(op)
|
||||
args.append(AssignParser._create_scopedvalue(ast_node.operand))
|
||||
elif isinstance(ast_node, ast.BoolOp):
|
||||
op = type(ast_node.op).__name__
|
||||
name = name + '_' + op
|
||||
ops['0'] = ScopedValue.create_naming_value(op)
|
||||
for value in ast_node.values:
|
||||
args.append(AssignParser._create_scopedvalue(value))
|
||||
elif isinstance(ast_node, ast.Compare):
|
||||
args.append(AssignParser._create_scopedvalue(ast_node.left))
|
||||
for idx, ast_op in enumerate(ast_node.ops):
|
||||
op = type(ast_op).__name__
|
||||
name = name + '_' + op
|
||||
ops[str(idx)] = ScopedValue.create_naming_value(op)
|
||||
args.append(AssignParser._create_scopedvalue(ast_node.comparators[idx]))
|
||||
name = name.lower()
|
||||
return Node.create_mathops_node(father_ast_node, targets, op_type, args, ops, name)
|
||||
|
||||
def process(self, stree: SymbolTree, node: ast.Assign):
|
||||
"""
|
||||
Parse ast.Assign and create a node in symbol tree.
|
||||
|
@ -493,11 +549,10 @@ class AssignParser(Parser):
|
|||
if isinstance(value, ast.Call):
|
||||
node_ = self._convert_ast_call_to_node(value, node, stree)
|
||||
stree.append_origin_field(node_)
|
||||
elif isinstance(value, ast.BinOp):
|
||||
logger.info(f"ops-call({astunparse.unparse(node)}) in assign will be supported in near feature, "
|
||||
f"ignored as a python node now")
|
||||
stree.try_append_python_node(node, node)
|
||||
elif isinstance(value, (ast.BoolOp, ast.Subscript)):
|
||||
elif isinstance(value, (ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare)):
|
||||
node_ = AssignParser._convert_ast_mathops_to_node(value, node)
|
||||
stree.append_origin_field(node_)
|
||||
elif isinstance(value, ast.Subscript):
|
||||
logger.info(f"ops-call({astunparse.unparse(node)}) in assign will be supported in near feature, "
|
||||
f"ignored as a python node now")
|
||||
stree.try_append_python_node(node, node)
|
||||
|
|
|
@ -1337,17 +1337,13 @@ class SymbolTree(Observer, Observable):
|
|||
tmp_module_name = tmp_module_file[:-3]
|
||||
sys.path.append(tmp_module_path)
|
||||
tmp_module = None
|
||||
i = 0
|
||||
while not tmp_module:
|
||||
try:
|
||||
tmp_module = importlib.import_module(tmp_module_name)
|
||||
except ModuleNotFoundError:
|
||||
while i > 10:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
i += 1
|
||||
|
||||
try:
|
||||
tmp_module = importlib.import_module(tmp_module_name)
|
||||
except ModuleNotFoundError:
|
||||
time.sleep(1)
|
||||
if not tmp_module:
|
||||
logger.error(f"load module {tmp_module_name} failed.")
|
||||
tmp_module = importlib.import_module(tmp_module_name)
|
||||
network_cls = getattr(tmp_module, self._opt_cls_name)
|
||||
if network_cls is None:
|
||||
raise RuntimeError("Can not find network class:", self._opt_cls_name)
|
||||
|
|
Loading…
Reference in New Issue