fix rewrite bug
This commit is contained in:
parent
a95fd90b07
commit
fa60ab7c84
|
@ -253,8 +253,29 @@ class SymbolTree:
|
||||||
def get_network(self) -> Cell:
|
def get_network(self) -> Cell:
|
||||||
"""
|
"""
|
||||||
Get modified network.
|
Get modified network.
|
||||||
|
The source code of network is saved to a file, the default file name is `network_define.py`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A network object.
|
A network object.
|
||||||
"""
|
"""
|
||||||
return self._symbol_tree.get_network()
|
return self._symbol_tree.get_network()
|
||||||
|
|
||||||
|
def set_saved_file_name(self, file_name: str):
|
||||||
|
"""
|
||||||
|
Set the name of the file used to save the network.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_name (str): filename to be set.
|
||||||
|
"""
|
||||||
|
Validator.check_value_type("file_name", file_name, [str], "Saving network")
|
||||||
|
self._symbol_tree.set_saved_file_name(file_name)
|
||||||
|
|
||||||
|
def get_saved_file_name(self):
|
||||||
|
"""Gets the filename used to save the network."""
|
||||||
|
return self._symbol_tree.get_saved_file_name()
|
||||||
|
|
||||||
|
def save_network_to_file(self):
|
||||||
|
"""
|
||||||
|
Save the modified network to a file. Default file name is `network_define.py`.
|
||||||
|
"""
|
||||||
|
self._symbol_tree.save_network_to_file()
|
||||||
|
|
|
@ -256,6 +256,75 @@ class AstModifier(ast.NodeTransformer):
|
||||||
ast.fix_missing_locations(result)
|
ast.fix_missing_locations(result)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _create_arg_by_single_value(value: ScopedValue):
|
||||||
|
"""
|
||||||
|
Create an instance of ast.Constant.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value (ScopedValue): value used to create arg.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: if scope of value is not empty.
|
||||||
|
TypeError: type of arg not in [ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ast.Constant: An instance of ast.Constant
|
||||||
|
"""
|
||||||
|
if value.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
|
||||||
|
if value.scope:
|
||||||
|
raise RuntimeError("For arg the scope should be empty")
|
||||||
|
return ast.Constant(value=value.value, kind=None)
|
||||||
|
raise TypeError("Type of arg only support [ValueType.IntValue, ValueType.FloatValue,"
|
||||||
|
f" ValueType.StringValue], but got {type(value)}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _create_list_or_tuple(value: ScopedValue):
|
||||||
|
"""
|
||||||
|
Create an instance of ast.List or ast.Tuple.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value (ScopedValue): value used to create ast node.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ast.List or ast.Tuple: An instance of ast.List or ast.Tuple.
|
||||||
|
"""
|
||||||
|
elts = []
|
||||||
|
for v in value.value:
|
||||||
|
elts.append(AstModifier._create_arg_by_single_value(v))
|
||||||
|
if isinstance(value, list):
|
||||||
|
return ast.List(elts=elts)
|
||||||
|
return ast.Tuple(elts=elts)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _create_keyword(arg: str, value: ScopedValue):
|
||||||
|
"""
|
||||||
|
Create an instance of ast.keyword.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arg (str): key of keyword.
|
||||||
|
value (ScopedValue): value used to create ast.keywrod instance.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: if scope of value is not empty.
|
||||||
|
TypeError: type of arg not in [ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue,
|
||||||
|
ValueType.ListValue, ValueType.TupleValue]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ast.keyword: a instance of ast.keyword.
|
||||||
|
"""
|
||||||
|
if value.scope:
|
||||||
|
raise RuntimeError("value.scope should be empty")
|
||||||
|
if value.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
|
||||||
|
v = ast.Constant(value=value.value, kind=None)
|
||||||
|
elif value.type in (ValueType.ListValue, ValueType.TupleValue):
|
||||||
|
v = AstModifier._create_list_or_tuple(value)
|
||||||
|
else:
|
||||||
|
raise TypeError("Type of keyword value only support [ValueType.IntValue, ValueType.FloatValue,"
|
||||||
|
"ValueType.StringValue, ValueType.ListValue, ValueType.TupleValue],"
|
||||||
|
f"but got {type(value)}")
|
||||||
|
return ast.keyword(arg=arg, value=v)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _create_call_args(args: [ScopedValue]) -> [ast.AST]:
|
def _create_call_args(args: [ScopedValue]) -> [ast.AST]:
|
||||||
"""
|
"""
|
||||||
|
@ -279,14 +348,14 @@ class AstModifier(ast.NodeTransformer):
|
||||||
if not isinstance(arg, ScopedValue):
|
if not isinstance(arg, ScopedValue):
|
||||||
raise TypeError("arg should be ScopedValue, got: ", type(arg))
|
raise TypeError("arg should be ScopedValue, got: ", type(arg))
|
||||||
if arg.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
|
if arg.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
|
||||||
if arg.scope:
|
results.append(AstModifier._create_arg_by_single_value(arg))
|
||||||
raise RuntimeError("arg.scope should be empty")
|
|
||||||
results.append(ast.Constant(value=arg.value, kind=None))
|
|
||||||
elif arg.type == ValueType.NamingValue:
|
elif arg.type == ValueType.NamingValue:
|
||||||
if arg.scope:
|
if arg.scope:
|
||||||
results.append(ast.Attribute(ast.Name(arg.scope, ast.Load()), arg.value, ast.Store()))
|
results.append(ast.Attribute(ast.Name(arg.scope, ast.Load()), arg.value, ast.Store()))
|
||||||
else:
|
else:
|
||||||
results.append(ast.Name(arg.value, ast.Store()))
|
results.append(ast.Name(arg.value, ast.Store()))
|
||||||
|
elif arg.type == ValueType.ListValue or arg.type == ValueType.TupleValue:
|
||||||
|
results.append(AstModifier._create_list_or_tuple(arg))
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Please handle custom-object first")
|
raise RuntimeError("Please handle custom-object first")
|
||||||
return results
|
return results
|
||||||
|
@ -313,10 +382,9 @@ class AstModifier(ast.NodeTransformer):
|
||||||
for arg, value in kwargs.items():
|
for arg, value in kwargs.items():
|
||||||
if not isinstance(value, ScopedValue):
|
if not isinstance(value, ScopedValue):
|
||||||
raise TypeError("value should be ScopedValue, got: ", type(value))
|
raise TypeError("value should be ScopedValue, got: ", type(value))
|
||||||
if value.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
|
if value.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue,
|
||||||
if value.scope:
|
ValueType.ListValue, ValueType.TupleValue):
|
||||||
raise RuntimeError("value.scope should be empty")
|
results.append(AstModifier._create_keyword(arg, value))
|
||||||
results.append(ast.keyword(arg=arg, value=ast.Constant(value=value.value, kind=None)))
|
|
||||||
elif value.type == ValueType.NamingValue:
|
elif value.type == ValueType.NamingValue:
|
||||||
if value.scope:
|
if value.scope:
|
||||||
results.append(ast.keyword(arg=arg, value=ast.Attribute(ast.Name(value.scope, ast.Load()),
|
results.append(ast.keyword(arg=arg, value=ast.Attribute(ast.Name(value.scope, ast.Load()),
|
||||||
|
|
|
@ -13,11 +13,12 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""SymbolTree class define of Rewrite according to forward function of a network."""
|
"""SymbolTree class define of Rewrite according to forward function of a network."""
|
||||||
|
import stat
|
||||||
from typing import Optional, Union, Tuple, Any
|
from typing import Optional, Union, Tuple, Any
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import ast
|
import ast
|
||||||
import tempfile
|
import importlib
|
||||||
import astunparse
|
import astunparse
|
||||||
|
|
||||||
from mindspore.nn import Cell
|
from mindspore.nn import Cell
|
||||||
|
@ -176,6 +177,7 @@ class SymbolTree(Observer, Observable):
|
||||||
|
|
||||||
self._tmp_file_limits = 20
|
self._tmp_file_limits = 20
|
||||||
self._tmp_files = []
|
self._tmp_files = []
|
||||||
|
self._saved_file_name = "./network_define.py"
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
for tmp_file in self._tmp_files:
|
for tmp_file in self._tmp_files:
|
||||||
|
@ -923,6 +925,27 @@ class SymbolTree(Observer, Observable):
|
||||||
cls = self._get_cls_through_file()
|
cls = self._get_cls_through_file()
|
||||||
return cls(self._global_vars)
|
return cls(self._global_vars)
|
||||||
|
|
||||||
|
def set_saved_file_name(self, file_name: str):
|
||||||
|
"""Sets the filename used to save the network."""
|
||||||
|
if file_name.endswith(".py"):
|
||||||
|
self._saved_file_name = file_name
|
||||||
|
else:
|
||||||
|
self._saved_file_name = file_name + ".py"
|
||||||
|
|
||||||
|
def get_saved_file_name(self):
|
||||||
|
"""Gets the filename used to save the network."""
|
||||||
|
return self._saved_file_name
|
||||||
|
|
||||||
|
def save_network_to_file(self):
|
||||||
|
"""Save the modified network to a file."""
|
||||||
|
abs_path = os.path.abspath(self._saved_file_name)
|
||||||
|
if os.path.isfile(abs_path):
|
||||||
|
os.remove(abs_path)
|
||||||
|
with os.fdopen(os.open(self._saved_file_name, os.O_WRONLY | os.O_CREAT, stat.S_IRWXU), 'wb') as f:
|
||||||
|
source = self.get_code()
|
||||||
|
f.write(source.encode('utf-8'))
|
||||||
|
f.flush()
|
||||||
|
|
||||||
def _remove_unused_import(self):
|
def _remove_unused_import(self):
|
||||||
"""remove unused import in self._module_ast"""
|
"""remove unused import in self._module_ast"""
|
||||||
str_checker = StrChecker(self._module_ast)
|
str_checker = StrChecker(self._module_ast)
|
||||||
|
@ -1177,19 +1200,12 @@ class SymbolTree(Observer, Observable):
|
||||||
Returns:
|
Returns:
|
||||||
A class handle.
|
A class handle.
|
||||||
"""
|
"""
|
||||||
source = self.get_code()
|
self.save_network_to_file()
|
||||||
tmp_file = tempfile.NamedTemporaryFile(suffix='.py')
|
tmp_module_path, tmp_module_file = os.path.split(self._saved_file_name)
|
||||||
tmp_file.write(source.encode('utf8'))
|
|
||||||
tmp_file.flush()
|
|
||||||
tmp_file_name = tmp_file.name
|
|
||||||
if len(self._tmp_files) >= self._tmp_file_limits:
|
|
||||||
raise RuntimeError(f"Too many tmp file generated, it may caused by calling get_network method too much "
|
|
||||||
f"times. Only support open {self._tmp_file_limits} tmp file at most now!")
|
|
||||||
self._tmp_files.append(tmp_file)
|
|
||||||
tmp_module_path, tmp_module_file = os.path.split(tmp_file_name)
|
|
||||||
tmp_module_name = tmp_module_file[:-3]
|
tmp_module_name = tmp_module_file[:-3]
|
||||||
sys.path.append(tmp_module_path)
|
sys.path.append(tmp_module_path)
|
||||||
tmp_module = __import__(tmp_module_name)
|
tmp_module = importlib.import_module(tmp_module_name)
|
||||||
|
importlib.reload(tmp_module)
|
||||||
network_cls = getattr(tmp_module, self._opt_cls_name)
|
network_cls = getattr(tmp_module, self._opt_cls_name)
|
||||||
if network_cls is None:
|
if network_cls is None:
|
||||||
raise RuntimeError("Can not find network class:", self._opt_cls_name)
|
raise RuntimeError("Can not find network class:", self._opt_cls_name)
|
||||||
|
|
|
@ -128,3 +128,51 @@ def test_create_by_cell3():
|
||||||
"h": ScopedValue.create_variable_value(1),
|
"h": ScopedValue.create_variable_value(1),
|
||||||
"cool_boy": ScopedValue.create_naming_value('Naroto'),
|
"cool_boy": ScopedValue.create_naming_value('Naroto'),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_by_cell4():
|
||||||
|
"""
|
||||||
|
Feature: Python api create_call_buildin_op of Node of Rewrite.
|
||||||
|
Description: Call create_call_buildin_op to create a CallCell node.
|
||||||
|
Expectation: Success.
|
||||||
|
"""
|
||||||
|
node = Node.create_call_buildin_op(FakeCell3(), None, [ScopedValue.create_naming_value('x')],
|
||||||
|
ScopedValue.create_naming_value('new_conv'),
|
||||||
|
[ScopedValue.create_naming_value('x'), ScopedValue.create_naming_value("x"),
|
||||||
|
ScopedValue.create_naming_value('x'), ScopedValue.create_naming_value("x")],
|
||||||
|
{"h": ScopedValue.create_variable_value([1]),
|
||||||
|
"f": ScopedValue.create_variable_value((2,)),
|
||||||
|
"cool_boy": ScopedValue.create_naming_value('Naroto')}, 'new_conv')
|
||||||
|
assert node.get_normalized_args() == {
|
||||||
|
"a": ScopedValue.create_naming_value('x'),
|
||||||
|
"b": ScopedValue.create_naming_value('x'),
|
||||||
|
"args_2": ScopedValue.create_naming_value('x'),
|
||||||
|
"args_3": ScopedValue.create_naming_value('x'),
|
||||||
|
"f": ScopedValue.create_variable_value((2,)),
|
||||||
|
"h": ScopedValue.create_variable_value([1]),
|
||||||
|
"cool_boy": ScopedValue.create_naming_value('Naroto'),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_by_cell5():
|
||||||
|
"""
|
||||||
|
Feature: Python api create_call_buildin_op of Node of Rewrite.
|
||||||
|
Description: Call create_call_buildin_op to create a CallCell node.
|
||||||
|
Expectation: Success.
|
||||||
|
"""
|
||||||
|
node = Node.create_call_buildin_op(FakeCell3(), None, [ScopedValue.create_naming_value('x')],
|
||||||
|
ScopedValue.create_naming_value('new_conv'),
|
||||||
|
[ScopedValue.create_variable_value((4,)), ScopedValue.create_variable_value(5),
|
||||||
|
ScopedValue.create_variable_value([5]), ScopedValue.create_naming_value("x")],
|
||||||
|
{"h": ScopedValue.create_variable_value(1),
|
||||||
|
"f": ScopedValue.create_variable_value(2),
|
||||||
|
"cool_boy": ScopedValue.create_naming_value('Naroto')}, 'new_conv')
|
||||||
|
assert node.get_normalized_args() == {
|
||||||
|
"a": ScopedValue.create_variable_value((4,)),
|
||||||
|
"b": ScopedValue.create_variable_value(5),
|
||||||
|
"args_2": ScopedValue.create_variable_value([5]),
|
||||||
|
"args_3": ScopedValue.create_naming_value('x'),
|
||||||
|
"f": ScopedValue.create_variable_value(2),
|
||||||
|
"h": ScopedValue.create_variable_value(1),
|
||||||
|
"cool_boy": ScopedValue.create_naming_value('Naroto'),
|
||||||
|
}
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
import os
|
||||||
import ast
|
import ast
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
|
@ -432,3 +433,38 @@ def test_replace_one_to_multi():
|
||||||
assert len(new_relu_node.get_targets()) == 1
|
assert len(new_relu_node.get_targets()) == 1
|
||||||
assert len(relu2.get_normalized_args().values()) == 1
|
assert len(relu2.get_normalized_args().values()) == 1
|
||||||
assert new_relu_node.get_targets()[0] == list(relu2.get_normalized_args().values())[0]
|
assert new_relu_node.get_targets()[0] == list(relu2.get_normalized_args().values())[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_saved_file_name():
|
||||||
|
"""
|
||||||
|
Feature: Python api set_saved_file_name and get_saved_file_name of SymbolTree of Rewrite.
|
||||||
|
Description: Call set_saved_file_name to set the filename used to save the network.
|
||||||
|
Call get_saved_file_name to get the filename used to save the network.
|
||||||
|
Expectation: Success.
|
||||||
|
"""
|
||||||
|
stree, _, _, _ = create_symbol_tree()
|
||||||
|
|
||||||
|
stree.set_saved_file_name("new_network.py")
|
||||||
|
new_file_name = stree.get_saved_file_name()
|
||||||
|
assert new_file_name == "new_network.py"
|
||||||
|
|
||||||
|
stree.set_saved_file_name("new_network_01")
|
||||||
|
new_file_name = stree.get_saved_file_name()
|
||||||
|
assert new_file_name == "new_network_01.py"
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_network_to_file():
|
||||||
|
"""
|
||||||
|
Feature: Python api save_network_to_file of SymbolTree of Rewrite.
|
||||||
|
Description: Call save_network_to_file to save the network to a file.
|
||||||
|
Expectation: Success.
|
||||||
|
"""
|
||||||
|
stree, bn, relu1, relu2 = create_symbol_tree()
|
||||||
|
stree.set_node_arg_by_node(relu2, 0, bn)
|
||||||
|
stree.erase_node(relu1)
|
||||||
|
|
||||||
|
stree.set_saved_file_name("new_network.py")
|
||||||
|
stree.save_network_to_file()
|
||||||
|
assert os.path.exists("./new_network.py")
|
||||||
|
|
||||||
|
os.system("rm -f new_network.py")
|
||||||
|
|
Loading…
Reference in New Issue