From e893c70164dc93da3dd26df2aef046f3ddf5943b Mon Sep 17 00:00:00 2001 From: jinyaohui Date: Wed, 24 Jun 2020 17:25:47 +0800 Subject: [PATCH] fix bug --- mindspore/context.py | 3 +++ mindspore/train/serialization.py | 10 ++++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/mindspore/context.py b/mindspore/context.py index ad601f8fab9..bf01032bf8d 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -564,6 +564,8 @@ def set_context(**kwargs): check_bprop (bool): Whether to check bprop. Default: False. max_device_memory (str): Sets the maximum memory available for device, currently only supported on GPU. The format is "xxGB". Default: "1024GB". + print_file_path (str): The path of print data to save. If this parameter is set, print data is saved to + a file by default,and turn off printing to the screen. Raises: ValueError: If input key is not an attribute in context. @@ -584,6 +586,7 @@ def set_context(**kwargs): >>> save_graphs_path="/mindspore") >>> context.set_context(enable_profiling=True, profiling_options="training_trace") >>> context.set_context(max_device_memory="3.5GB") + >>> context.set_context(print_file_path="print.pb") """ for key, value in kwargs.items(): if not hasattr(_context(), key): diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index ce776d68212..c4fba9ba8a5 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -29,8 +29,7 @@ from mindspore.common.api import _executor from mindspore.common import dtype as mstype from mindspore._checkparam import check_input_data - -__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export"] +__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print"] tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16, "Int32": mstype.int32, "Uint32": mstype.uint32, "Int64": mstype.int64, "Uint64": mstype.uint64, @@ -513,6 +512,13 @@ def parse_print(print_file_name): tensor_list.append(Tensor(param_value, ms_type)) # Scale type else: + data_type_ = data_type.lower() + if 'float' in data_type_: + param_data = float(param_data[0]) + elif 'int' in data_type_: + param_data = int(param_data[0]) + elif 'bool' in data_type_: + param_data = bool(param_data[0]) tensor_list.append(Tensor(param_data, ms_type)) except BaseException as e: