update api doc of save_graphs_path

This commit is contained in:
huanghui 2021-03-02 20:18:55 +08:00
parent 6f834db9b2
commit 1aee8ffb7c
1 changed files with 12 additions and 2 deletions

View File

@ -36,10 +36,11 @@ __all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_aut
GRAPH_MODE = 0
PYNATIVE_MODE = 1
_DEVICE_APP_MEMORY_SIZE = 31 # The max memory size of graph plus variable.
_DEVICE_APP_MEMORY_SIZE = 31 # The max memory size of graph plus variable.
_re_pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'
_k_context = None
def _make_directory(path):
"""Make directory."""
real_path = None
@ -432,6 +433,7 @@ def set_auto_parallel_context(**kwargs):
"""
_set_auto_parallel_context(**kwargs)
def get_auto_parallel_context(attr_key):
"""
Gets auto parallel context attribute value according to the key.
@ -542,7 +544,13 @@ def set_context(**kwargs):
device_id (int): ID of the target device, the value must be in [0, device_num_per_host-1],
while device_num_per_host should be no more than 4096. Default: 0.
save_graphs (bool): Whether to save graphs. Default: False.
save_graphs_path (str): Path to save graphs. Default: "."
save_graphs_path (str): Path to save graphs. Default: ".".
If the program is executed in the parallel mode, `save_graphs_path` should consist of the path and the
current device id, to ensure that writing file conflicts won't happen when the different processes try to
create the files in the same directory. For example, the `device_id` can be generated by
`device_id = os.getenv("DEVICE_ID")` and the `save_graphs_path` can be set by
`context.set_context(save_graphs_path="path/to/ir/files"+device_id)`.
enable_graph_kernel (bool): Whether to enable composition of basic primitives. These primitives would be
compiled into a fused kernel automatically. Default: False.
reserve_class_name_in_scope (bool) : Whether to save the network class name in the scope. Default: True.
@ -700,6 +708,7 @@ class ParallelMode:
AUTO_PARALLEL = "auto_parallel"
MODE_LIST = [STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, AUTO_PARALLEL]
@args_type_check(enable_ps=bool)
def set_ps_context(**kwargs):
"""
@ -750,6 +759,7 @@ def get_ps_context(attr_key):
"""
return _get_ps_context(attr_key)
def reset_ps_context():
"""
Reset parameter server training mode context attributes to the default values: