forked from mindspore-Ecosystem/mindspore
commit
3c48de8262
|
@ -564,6 +564,8 @@ def set_context(**kwargs):
|
||||||
check_bprop (bool): Whether to check bprop. Default: False.
|
check_bprop (bool): Whether to check bprop. Default: False.
|
||||||
max_device_memory (str): Sets the maximum memory available for device, currently only supported on GPU.
|
max_device_memory (str): Sets the maximum memory available for device, currently only supported on GPU.
|
||||||
The format is "xxGB". Default: "1024GB".
|
The format is "xxGB". Default: "1024GB".
|
||||||
|
print_file_path (str): The path of print data to save. If this parameter is set, print data is saved to
|
||||||
|
a file by default,and turn off printing to the screen.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If input key is not an attribute in context.
|
ValueError: If input key is not an attribute in context.
|
||||||
|
@ -584,6 +586,7 @@ def set_context(**kwargs):
|
||||||
>>> save_graphs_path="/mindspore")
|
>>> save_graphs_path="/mindspore")
|
||||||
>>> context.set_context(enable_profiling=True, profiling_options="training_trace")
|
>>> context.set_context(enable_profiling=True, profiling_options="training_trace")
|
||||||
>>> context.set_context(max_device_memory="3.5GB")
|
>>> context.set_context(max_device_memory="3.5GB")
|
||||||
|
>>> context.set_context(print_file_path="print.pb")
|
||||||
"""
|
"""
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
if not hasattr(_context(), key):
|
if not hasattr(_context(), key):
|
||||||
|
|
|
@ -29,8 +29,7 @@ from mindspore.common.api import _executor
|
||||||
from mindspore.common import dtype as mstype
|
from mindspore.common import dtype as mstype
|
||||||
from mindspore._checkparam import check_input_data
|
from mindspore._checkparam import check_input_data
|
||||||
|
|
||||||
|
__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print"]
|
||||||
__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export"]
|
|
||||||
|
|
||||||
tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
|
tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
|
||||||
"Int32": mstype.int32, "Uint32": mstype.uint32, "Int64": mstype.int64, "Uint64": mstype.uint64,
|
"Int32": mstype.int32, "Uint32": mstype.uint32, "Int64": mstype.int64, "Uint64": mstype.uint64,
|
||||||
|
@ -513,6 +512,13 @@ def parse_print(print_file_name):
|
||||||
tensor_list.append(Tensor(param_value, ms_type))
|
tensor_list.append(Tensor(param_value, ms_type))
|
||||||
# Scale type
|
# Scale type
|
||||||
else:
|
else:
|
||||||
|
data_type_ = data_type.lower()
|
||||||
|
if 'float' in data_type_:
|
||||||
|
param_data = float(param_data[0])
|
||||||
|
elif 'int' in data_type_:
|
||||||
|
param_data = int(param_data[0])
|
||||||
|
elif 'bool' in data_type_:
|
||||||
|
param_data = bool(param_data[0])
|
||||||
tensor_list.append(Tensor(param_data, ms_type))
|
tensor_list.append(Tensor(param_data, ms_type))
|
||||||
|
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
|
|
Loading…
Reference in New Issue