forked from mindspore-Ecosystem/mindspore
fix print file bug
This commit is contained in:
parent
c0e454c07b
commit
c7f6527e92
|
@ -256,6 +256,7 @@ bool SaveDataItem2File(const std::vector<tdt::DataItem> &items, const std::strin
|
|||
if (!print.SerializeToOstream(output)) {
|
||||
MS_LOG(ERROR) << "Save print file:" << print_file_path << " fail.";
|
||||
ret_end_thread = true;
|
||||
break;
|
||||
}
|
||||
print.Clear();
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@ The context of mindspore, used to configure the current execution environment,
|
|||
including execution mode, execution backend and other feature switches.
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
import threading
|
||||
from collections import namedtuple
|
||||
from types import FunctionType
|
||||
|
@ -55,12 +56,20 @@ def _make_directory(path):
|
|||
os.makedirs(path)
|
||||
real_path = path
|
||||
except PermissionError as e:
|
||||
logger.error(
|
||||
f"No write permission on the directory `{path}, error = {e}")
|
||||
logger.error(f"No write permission on the directory `{path}, error = {e}")
|
||||
raise ValueError(f"No write permission on the directory `{path}`.")
|
||||
return real_path
|
||||
|
||||
|
||||
def _get_print_file_name(file_name):
|
||||
"""Add timestamp suffix to file name. Rename the file name: file_name + "." + time(seconds)."""
|
||||
time_second = str(int(time.time()))
|
||||
file_name = file_name + "." + time_second
|
||||
if os.path.exists(file_name):
|
||||
ValueError("This file {} already exists.".format(file_name))
|
||||
return file_name
|
||||
|
||||
|
||||
class _ThreadLocalInfo(threading.local):
|
||||
"""
|
||||
Thread local Info used for store thread local attributes.
|
||||
|
@ -381,8 +390,20 @@ class _Context:
|
|||
return None
|
||||
|
||||
@print_file_path.setter
|
||||
def print_file_path(self, file):
|
||||
self._context_handle.set_print_file_path(file)
|
||||
def print_file_path(self, file_path):
|
||||
"""Add timestamp suffix to file name. Sets print file path."""
|
||||
print_file_path = os.path.realpath(file_path)
|
||||
if os.path.isdir(print_file_path):
|
||||
raise IOError("Print_file_path should be file path, but got {}.".format(file_path))
|
||||
|
||||
if os.path.exists(print_file_path):
|
||||
_path, _file_name = os.path.split(print_file_path)
|
||||
path = _make_directory(_path)
|
||||
file_name = _get_print_file_name(_file_name)
|
||||
full_file_name = os.path.join(path, file_name)
|
||||
else:
|
||||
full_file_name = print_file_path
|
||||
self._context_handle.set_print_file_path(full_file_name)
|
||||
|
||||
|
||||
def check_input_format(x):
|
||||
|
@ -575,7 +596,8 @@ def set_context(**kwargs):
|
|||
max_device_memory (str): Sets the maximum memory available for device, currently only supported on GPU.
|
||||
The format is "xxGB". Default: "1024GB".
|
||||
print_file_path (str): The path of print data to save. If this parameter is set, print data is saved to
|
||||
a file by default, and turn off printing to the screen.
|
||||
a file by default, and turn off printing to the screen. If the file already exists, add a timestamp
|
||||
suffix to the file.
|
||||
enable_sparse (bool): Whether to enable sparse feature. Default: False.
|
||||
|
||||
Raises:
|
||||
|
|
|
@ -302,7 +302,7 @@ def _save_graph(network, file_name):
|
|||
if graph_proto:
|
||||
with open(file_name, "wb") as f:
|
||||
f.write(graph_proto)
|
||||
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
|
||||
os.chmod(file_name, stat.S_IRUSR)
|
||||
|
||||
|
||||
def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True):
|
||||
|
@ -462,19 +462,18 @@ def parse_print(print_file_name):
|
|||
List, element of list is Tensor.
|
||||
|
||||
Raises:
|
||||
ValueError: Print file is incorrect.
|
||||
ValueError: The print file may be empty, please make sure enter the correct file name.
|
||||
"""
|
||||
if not os.path.realpath(print_file_name):
|
||||
raise ValueError("Please input the correct print file name.")
|
||||
print_file_path = os.path.realpath(print_file_name)
|
||||
|
||||
if os.path.getsize(print_file_name) == 0:
|
||||
if os.path.getsize(print_file_path) == 0:
|
||||
raise ValueError("The print file may be empty, please make sure enter the correct file name.")
|
||||
|
||||
logger.info("Execute load print process.")
|
||||
print_list = Print()
|
||||
|
||||
try:
|
||||
with open(print_file_name, "rb") as f:
|
||||
with open(print_file_path, "rb") as f:
|
||||
pb_content = f.read()
|
||||
print_list.ParseFromString(pb_content)
|
||||
except BaseException as e:
|
||||
|
|
|
@ -118,6 +118,12 @@ def test_variable_memory_max_size():
|
|||
context.set_context(variable_memory_max_size="3GB")
|
||||
|
||||
|
||||
def test_print_file_path():
|
||||
"""test_print_file_path"""
|
||||
with pytest.raises(IOError):
|
||||
context.set_context(print_file_path="./")
|
||||
|
||||
|
||||
def test_set_context():
|
||||
""" test_set_context """
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
|
||||
|
|
|
@ -34,7 +34,7 @@ from mindspore.train.serialization import save_checkpoint, load_checkpoint, load
|
|||
_exec_save_checkpoint, export, _save_graph
|
||||
from ..ut_filter import non_graph_engine
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, print_file_path="print.pb")
|
||||
context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
|
@ -374,10 +374,13 @@ def test_print():
|
|||
|
||||
|
||||
def teardown_module():
|
||||
files = ['parameters.ckpt', 'new_ckpt.ckpt', 'empty.ckpt', 'print.pb']
|
||||
files = ['parameters.ckpt', 'new_ckpt.ckpt', 'empty.ckpt']
|
||||
for item in files:
|
||||
file_name = './' + item
|
||||
if not os.path.exists(file_name):
|
||||
continue
|
||||
os.chmod(file_name, stat.S_IWRITE)
|
||||
os.remove(file_name)
|
||||
import shutil
|
||||
if os.path.exists('./print'):
|
||||
shutil.rmtree('./print')
|
||||
|
|
Loading…
Reference in New Issue