forked from mindspore-Ecosystem/mindspore
update api doc of save_graphs_path
This commit is contained in:
parent
6f834db9b2
commit
1aee8ffb7c
|
@ -36,10 +36,11 @@ __all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_aut
|
|||
|
||||
GRAPH_MODE = 0
|
||||
PYNATIVE_MODE = 1
|
||||
_DEVICE_APP_MEMORY_SIZE = 31 # The max memory size of graph plus variable.
|
||||
_DEVICE_APP_MEMORY_SIZE = 31 # The max memory size of graph plus variable.
|
||||
_re_pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'
|
||||
_k_context = None
|
||||
|
||||
|
||||
def _make_directory(path):
|
||||
"""Make directory."""
|
||||
real_path = None
|
||||
|
@ -432,6 +433,7 @@ def set_auto_parallel_context(**kwargs):
|
|||
"""
|
||||
_set_auto_parallel_context(**kwargs)
|
||||
|
||||
|
||||
def get_auto_parallel_context(attr_key):
|
||||
"""
|
||||
Gets auto parallel context attribute value according to the key.
|
||||
|
@ -542,7 +544,13 @@ def set_context(**kwargs):
|
|||
device_id (int): ID of the target device, the value must be in [0, device_num_per_host-1],
|
||||
while device_num_per_host should be no more than 4096. Default: 0.
|
||||
save_graphs (bool): Whether to save graphs. Default: False.
|
||||
save_graphs_path (str): Path to save graphs. Default: "."
|
||||
save_graphs_path (str): Path to save graphs. Default: ".".
|
||||
|
||||
If the program is executed in the parallel mode, `save_graphs_path` should consist of the path and the
|
||||
current device id, to ensure that writing file conflicts won't happen when the different processes try to
|
||||
create the files in the same directory. For example, the `device_id` can be generated by
|
||||
`device_id = os.getenv("DEVICE_ID")` and the `save_graphs_path` can be set by
|
||||
`context.set_context(save_graphs_path="path/to/ir/files"+device_id)`.
|
||||
enable_graph_kernel (bool): Whether to enable composition of basic primitives. These primitives would be
|
||||
compiled into a fused kernel automatically. Default: False.
|
||||
reserve_class_name_in_scope (bool) : Whether to save the network class name in the scope. Default: True.
|
||||
|
@ -700,6 +708,7 @@ class ParallelMode:
|
|||
AUTO_PARALLEL = "auto_parallel"
|
||||
MODE_LIST = [STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, AUTO_PARALLEL]
|
||||
|
||||
|
||||
@args_type_check(enable_ps=bool)
|
||||
def set_ps_context(**kwargs):
|
||||
"""
|
||||
|
@ -750,6 +759,7 @@ def get_ps_context(attr_key):
|
|||
"""
|
||||
return _get_ps_context(attr_key)
|
||||
|
||||
|
||||
def reset_ps_context():
|
||||
"""
|
||||
Reset parameter server training mode context attributes to the default values:
|
||||
|
|
Loading…
Reference in New Issue