diff --git a/mindspore/context.py b/mindspore/context.py index 78c7d19a626..ba00646a9b0 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -36,10 +36,11 @@ __all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_aut GRAPH_MODE = 0 PYNATIVE_MODE = 1 -_DEVICE_APP_MEMORY_SIZE = 31 # The max memory size of graph plus variable. +_DEVICE_APP_MEMORY_SIZE = 31 # The max memory size of graph plus variable. _re_pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB' _k_context = None + def _make_directory(path): """Make directory.""" real_path = None @@ -432,6 +433,7 @@ def set_auto_parallel_context(**kwargs): """ _set_auto_parallel_context(**kwargs) + def get_auto_parallel_context(attr_key): """ Gets auto parallel context attribute value according to the key. @@ -542,7 +544,13 @@ def set_context(**kwargs): device_id (int): ID of the target device, the value must be in [0, device_num_per_host-1], while device_num_per_host should be no more than 4096. Default: 0. save_graphs (bool): Whether to save graphs. Default: False. - save_graphs_path (str): Path to save graphs. Default: "." + save_graphs_path (str): Path to save graphs. Default: ".". + + If the program is executed in the parallel mode, `save_graphs_path` should consist of the path and the + current device id, to ensure that writing file conflicts won't happen when the different processes try to + create the files in the same directory. For example, the `device_id` can be generated by + `device_id = os.getenv("DEVICE_ID")` and the `save_graphs_path` can be set by + `context.set_context(save_graphs_path="path/to/ir/files"+device_id)`. enable_graph_kernel (bool): Whether to enable composition of basic primitives. These primitives would be compiled into a fused kernel automatically. Default: False. reserve_class_name_in_scope (bool) : Whether to save the network class name in the scope. Default: True. @@ -700,6 +708,7 @@ class ParallelMode: AUTO_PARALLEL = "auto_parallel" MODE_LIST = [STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, AUTO_PARALLEL] + @args_type_check(enable_ps=bool) def set_ps_context(**kwargs): """ @@ -750,6 +759,7 @@ def get_ps_context(attr_key): """ return _get_ps_context(attr_key) + def reset_ps_context(): """ Reset parameter server training mode context attributes to the default values: