!2573 fix print file bug

Merge pull request !2573 from jinyaohui/print
This commit is contained in:
mindspore-ci-bot 2020-06-24 20:27:04 +08:00 committed by Gitee
commit 3c48de8262
2 changed files with 11 additions and 2 deletions

View File

@ -564,6 +564,8 @@ def set_context(**kwargs):
check_bprop (bool): Whether to check bprop. Default: False.
max_device_memory (str): Sets the maximum memory available for device, currently only supported on GPU.
The format is "xxGB". Default: "1024GB".
print_file_path (str): The path of print data to save. If this parameter is set, print data is saved to
a file by defaultand turn off printing to the screen.
Raises:
ValueError: If input key is not an attribute in context.
@ -584,6 +586,7 @@ def set_context(**kwargs):
>>> save_graphs_path="/mindspore")
>>> context.set_context(enable_profiling=True, profiling_options="training_trace")
>>> context.set_context(max_device_memory="3.5GB")
>>> context.set_context(print_file_path="print.pb")
"""
for key, value in kwargs.items():
if not hasattr(_context(), key):

View File

@ -29,8 +29,7 @@ from mindspore.common.api import _executor
from mindspore.common import dtype as mstype
from mindspore._checkparam import check_input_data
__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export"]
__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print"]
tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
"Int32": mstype.int32, "Uint32": mstype.uint32, "Int64": mstype.int64, "Uint64": mstype.uint64,
@ -513,6 +512,13 @@ def parse_print(print_file_name):
tensor_list.append(Tensor(param_value, ms_type))
# Scale type
else:
data_type_ = data_type.lower()
if 'float' in data_type_:
param_data = float(param_data[0])
elif 'int' in data_type_:
param_data = int(param_data[0])
elif 'bool' in data_type_:
param_data = bool(param_data[0])
tensor_list.append(Tensor(param_data, ms_type))
except BaseException as e: