fix rewrite bug
This commit is contained in:
parent
a95fd90b07
commit
fa60ab7c84
|
@ -253,8 +253,29 @@ class SymbolTree:
|
|||
def get_network(self) -> Cell:
|
||||
"""
|
||||
Get modified network.
|
||||
The source code of network is saved to a file, the default file name is `network_define.py`.
|
||||
|
||||
Returns:
|
||||
A network object.
|
||||
"""
|
||||
return self._symbol_tree.get_network()
|
||||
|
||||
def set_saved_file_name(self, file_name: str):
|
||||
"""
|
||||
Set the name of the file used to save the network.
|
||||
|
||||
Args:
|
||||
file_name (str): filename to be set.
|
||||
"""
|
||||
Validator.check_value_type("file_name", file_name, [str], "Saving network")
|
||||
self._symbol_tree.set_saved_file_name(file_name)
|
||||
|
||||
def get_saved_file_name(self):
|
||||
"""Gets the filename used to save the network."""
|
||||
return self._symbol_tree.get_saved_file_name()
|
||||
|
||||
def save_network_to_file(self):
|
||||
"""
|
||||
Save the modified network to a file. Default file name is `network_define.py`.
|
||||
"""
|
||||
self._symbol_tree.save_network_to_file()
|
||||
|
|
|
@ -256,6 +256,75 @@ class AstModifier(ast.NodeTransformer):
|
|||
ast.fix_missing_locations(result)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _create_arg_by_single_value(value: ScopedValue):
|
||||
"""
|
||||
Create an instance of ast.Constant.
|
||||
|
||||
Args:
|
||||
value (ScopedValue): value used to create arg.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if scope of value is not empty.
|
||||
TypeError: type of arg not in [ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue]
|
||||
|
||||
Returns:
|
||||
ast.Constant: An instance of ast.Constant
|
||||
"""
|
||||
if value.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
|
||||
if value.scope:
|
||||
raise RuntimeError("For arg the scope should be empty")
|
||||
return ast.Constant(value=value.value, kind=None)
|
||||
raise TypeError("Type of arg only support [ValueType.IntValue, ValueType.FloatValue,"
|
||||
f" ValueType.StringValue], but got {type(value)}")
|
||||
|
||||
@staticmethod
|
||||
def _create_list_or_tuple(value: ScopedValue):
|
||||
"""
|
||||
Create an instance of ast.List or ast.Tuple.
|
||||
|
||||
Args:
|
||||
value (ScopedValue): value used to create ast node.
|
||||
|
||||
Returns:
|
||||
ast.List or ast.Tuple: An instance of ast.List or ast.Tuple.
|
||||
"""
|
||||
elts = []
|
||||
for v in value.value:
|
||||
elts.append(AstModifier._create_arg_by_single_value(v))
|
||||
if isinstance(value, list):
|
||||
return ast.List(elts=elts)
|
||||
return ast.Tuple(elts=elts)
|
||||
|
||||
@staticmethod
|
||||
def _create_keyword(arg: str, value: ScopedValue):
|
||||
"""
|
||||
Create an instance of ast.keyword.
|
||||
|
||||
Args:
|
||||
arg (str): key of keyword.
|
||||
value (ScopedValue): value used to create ast.keywrod instance.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if scope of value is not empty.
|
||||
TypeError: type of arg not in [ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue,
|
||||
ValueType.ListValue, ValueType.TupleValue]
|
||||
|
||||
Returns:
|
||||
ast.keyword: a instance of ast.keyword.
|
||||
"""
|
||||
if value.scope:
|
||||
raise RuntimeError("value.scope should be empty")
|
||||
if value.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
|
||||
v = ast.Constant(value=value.value, kind=None)
|
||||
elif value.type in (ValueType.ListValue, ValueType.TupleValue):
|
||||
v = AstModifier._create_list_or_tuple(value)
|
||||
else:
|
||||
raise TypeError("Type of keyword value only support [ValueType.IntValue, ValueType.FloatValue,"
|
||||
"ValueType.StringValue, ValueType.ListValue, ValueType.TupleValue],"
|
||||
f"but got {type(value)}")
|
||||
return ast.keyword(arg=arg, value=v)
|
||||
|
||||
@staticmethod
|
||||
def _create_call_args(args: [ScopedValue]) -> [ast.AST]:
|
||||
"""
|
||||
|
@ -279,14 +348,14 @@ class AstModifier(ast.NodeTransformer):
|
|||
if not isinstance(arg, ScopedValue):
|
||||
raise TypeError("arg should be ScopedValue, got: ", type(arg))
|
||||
if arg.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
|
||||
if arg.scope:
|
||||
raise RuntimeError("arg.scope should be empty")
|
||||
results.append(ast.Constant(value=arg.value, kind=None))
|
||||
results.append(AstModifier._create_arg_by_single_value(arg))
|
||||
elif arg.type == ValueType.NamingValue:
|
||||
if arg.scope:
|
||||
results.append(ast.Attribute(ast.Name(arg.scope, ast.Load()), arg.value, ast.Store()))
|
||||
else:
|
||||
results.append(ast.Name(arg.value, ast.Store()))
|
||||
elif arg.type == ValueType.ListValue or arg.type == ValueType.TupleValue:
|
||||
results.append(AstModifier._create_list_or_tuple(arg))
|
||||
else:
|
||||
raise RuntimeError("Please handle custom-object first")
|
||||
return results
|
||||
|
@ -313,10 +382,9 @@ class AstModifier(ast.NodeTransformer):
|
|||
for arg, value in kwargs.items():
|
||||
if not isinstance(value, ScopedValue):
|
||||
raise TypeError("value should be ScopedValue, got: ", type(value))
|
||||
if value.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue):
|
||||
if value.scope:
|
||||
raise RuntimeError("value.scope should be empty")
|
||||
results.append(ast.keyword(arg=arg, value=ast.Constant(value=value.value, kind=None)))
|
||||
if value.type in (ValueType.IntValue, ValueType.FloatValue, ValueType.StringValue,
|
||||
ValueType.ListValue, ValueType.TupleValue):
|
||||
results.append(AstModifier._create_keyword(arg, value))
|
||||
elif value.type == ValueType.NamingValue:
|
||||
if value.scope:
|
||||
results.append(ast.keyword(arg=arg, value=ast.Attribute(ast.Name(value.scope, ast.Load()),
|
||||
|
|
|
@ -13,11 +13,12 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""SymbolTree class define of Rewrite according to forward function of a network."""
|
||||
import stat
|
||||
from typing import Optional, Union, Tuple, Any
|
||||
import os
|
||||
import sys
|
||||
import ast
|
||||
import tempfile
|
||||
import importlib
|
||||
import astunparse
|
||||
|
||||
from mindspore.nn import Cell
|
||||
|
@ -176,6 +177,7 @@ class SymbolTree(Observer, Observable):
|
|||
|
||||
self._tmp_file_limits = 20
|
||||
self._tmp_files = []
|
||||
self._saved_file_name = "./network_define.py"
|
||||
|
||||
def __del__(self):
|
||||
for tmp_file in self._tmp_files:
|
||||
|
@ -923,6 +925,27 @@ class SymbolTree(Observer, Observable):
|
|||
cls = self._get_cls_through_file()
|
||||
return cls(self._global_vars)
|
||||
|
||||
def set_saved_file_name(self, file_name: str):
|
||||
"""Sets the filename used to save the network."""
|
||||
if file_name.endswith(".py"):
|
||||
self._saved_file_name = file_name
|
||||
else:
|
||||
self._saved_file_name = file_name + ".py"
|
||||
|
||||
def get_saved_file_name(self):
|
||||
"""Gets the filename used to save the network."""
|
||||
return self._saved_file_name
|
||||
|
||||
def save_network_to_file(self):
|
||||
"""Save the modified network to a file."""
|
||||
abs_path = os.path.abspath(self._saved_file_name)
|
||||
if os.path.isfile(abs_path):
|
||||
os.remove(abs_path)
|
||||
with os.fdopen(os.open(self._saved_file_name, os.O_WRONLY | os.O_CREAT, stat.S_IRWXU), 'wb') as f:
|
||||
source = self.get_code()
|
||||
f.write(source.encode('utf-8'))
|
||||
f.flush()
|
||||
|
||||
def _remove_unused_import(self):
|
||||
"""remove unused import in self._module_ast"""
|
||||
str_checker = StrChecker(self._module_ast)
|
||||
|
@ -1177,19 +1200,12 @@ class SymbolTree(Observer, Observable):
|
|||
Returns:
|
||||
A class handle.
|
||||
"""
|
||||
source = self.get_code()
|
||||
tmp_file = tempfile.NamedTemporaryFile(suffix='.py')
|
||||
tmp_file.write(source.encode('utf8'))
|
||||
tmp_file.flush()
|
||||
tmp_file_name = tmp_file.name
|
||||
if len(self._tmp_files) >= self._tmp_file_limits:
|
||||
raise RuntimeError(f"Too many tmp file generated, it may caused by calling get_network method too much "
|
||||
f"times. Only support open {self._tmp_file_limits} tmp file at most now!")
|
||||
self._tmp_files.append(tmp_file)
|
||||
tmp_module_path, tmp_module_file = os.path.split(tmp_file_name)
|
||||
self.save_network_to_file()
|
||||
tmp_module_path, tmp_module_file = os.path.split(self._saved_file_name)
|
||||
tmp_module_name = tmp_module_file[:-3]
|
||||
sys.path.append(tmp_module_path)
|
||||
tmp_module = __import__(tmp_module_name)
|
||||
tmp_module = importlib.import_module(tmp_module_name)
|
||||
importlib.reload(tmp_module)
|
||||
network_cls = getattr(tmp_module, self._opt_cls_name)
|
||||
if network_cls is None:
|
||||
raise RuntimeError("Can not find network class:", self._opt_cls_name)
|
||||
|
|
|
@ -128,3 +128,51 @@ def test_create_by_cell3():
|
|||
"h": ScopedValue.create_variable_value(1),
|
||||
"cool_boy": ScopedValue.create_naming_value('Naroto'),
|
||||
}
|
||||
|
||||
|
||||
def test_create_by_cell4():
|
||||
"""
|
||||
Feature: Python api create_call_buildin_op of Node of Rewrite.
|
||||
Description: Call create_call_buildin_op to create a CallCell node.
|
||||
Expectation: Success.
|
||||
"""
|
||||
node = Node.create_call_buildin_op(FakeCell3(), None, [ScopedValue.create_naming_value('x')],
|
||||
ScopedValue.create_naming_value('new_conv'),
|
||||
[ScopedValue.create_naming_value('x'), ScopedValue.create_naming_value("x"),
|
||||
ScopedValue.create_naming_value('x'), ScopedValue.create_naming_value("x")],
|
||||
{"h": ScopedValue.create_variable_value([1]),
|
||||
"f": ScopedValue.create_variable_value((2,)),
|
||||
"cool_boy": ScopedValue.create_naming_value('Naroto')}, 'new_conv')
|
||||
assert node.get_normalized_args() == {
|
||||
"a": ScopedValue.create_naming_value('x'),
|
||||
"b": ScopedValue.create_naming_value('x'),
|
||||
"args_2": ScopedValue.create_naming_value('x'),
|
||||
"args_3": ScopedValue.create_naming_value('x'),
|
||||
"f": ScopedValue.create_variable_value((2,)),
|
||||
"h": ScopedValue.create_variable_value([1]),
|
||||
"cool_boy": ScopedValue.create_naming_value('Naroto'),
|
||||
}
|
||||
|
||||
|
||||
def test_create_by_cell5():
|
||||
"""
|
||||
Feature: Python api create_call_buildin_op of Node of Rewrite.
|
||||
Description: Call create_call_buildin_op to create a CallCell node.
|
||||
Expectation: Success.
|
||||
"""
|
||||
node = Node.create_call_buildin_op(FakeCell3(), None, [ScopedValue.create_naming_value('x')],
|
||||
ScopedValue.create_naming_value('new_conv'),
|
||||
[ScopedValue.create_variable_value((4,)), ScopedValue.create_variable_value(5),
|
||||
ScopedValue.create_variable_value([5]), ScopedValue.create_naming_value("x")],
|
||||
{"h": ScopedValue.create_variable_value(1),
|
||||
"f": ScopedValue.create_variable_value(2),
|
||||
"cool_boy": ScopedValue.create_naming_value('Naroto')}, 'new_conv')
|
||||
assert node.get_normalized_args() == {
|
||||
"a": ScopedValue.create_variable_value((4,)),
|
||||
"b": ScopedValue.create_variable_value(5),
|
||||
"args_2": ScopedValue.create_variable_value([5]),
|
||||
"args_3": ScopedValue.create_naming_value('x'),
|
||||
"f": ScopedValue.create_variable_value(2),
|
||||
"h": ScopedValue.create_variable_value(1),
|
||||
"cool_boy": ScopedValue.create_naming_value('Naroto'),
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import os
|
||||
import ast
|
||||
import inspect
|
||||
|
||||
|
@ -432,3 +433,38 @@ def test_replace_one_to_multi():
|
|||
assert len(new_relu_node.get_targets()) == 1
|
||||
assert len(relu2.get_normalized_args().values()) == 1
|
||||
assert new_relu_node.get_targets()[0] == list(relu2.get_normalized_args().values())[0]
|
||||
|
||||
|
||||
def test_set_saved_file_name():
|
||||
"""
|
||||
Feature: Python api set_saved_file_name and get_saved_file_name of SymbolTree of Rewrite.
|
||||
Description: Call set_saved_file_name to set the filename used to save the network.
|
||||
Call get_saved_file_name to get the filename used to save the network.
|
||||
Expectation: Success.
|
||||
"""
|
||||
stree, _, _, _ = create_symbol_tree()
|
||||
|
||||
stree.set_saved_file_name("new_network.py")
|
||||
new_file_name = stree.get_saved_file_name()
|
||||
assert new_file_name == "new_network.py"
|
||||
|
||||
stree.set_saved_file_name("new_network_01")
|
||||
new_file_name = stree.get_saved_file_name()
|
||||
assert new_file_name == "new_network_01.py"
|
||||
|
||||
|
||||
def test_save_network_to_file():
|
||||
"""
|
||||
Feature: Python api save_network_to_file of SymbolTree of Rewrite.
|
||||
Description: Call save_network_to_file to save the network to a file.
|
||||
Expectation: Success.
|
||||
"""
|
||||
stree, bn, relu1, relu2 = create_symbol_tree()
|
||||
stree.set_node_arg_by_node(relu2, 0, bn)
|
||||
stree.erase_node(relu1)
|
||||
|
||||
stree.set_saved_file_name("new_network.py")
|
||||
stree.save_network_to_file()
|
||||
assert os.path.exists("./new_network.py")
|
||||
|
||||
os.system("rm -f new_network.py")
|
||||
|
|
Loading…
Reference in New Issue