!5736 Add device specific config key checking
Merge pull request !5736 from fary86/add_device_specific_config_check
This commit is contained in:
commit
25a528ae12
|
@ -120,7 +120,8 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
|
|||
engine->IncreaseFunctionCallDepth();
|
||||
if (engine->function_call_depth() > MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)) {
|
||||
MS_LOG(EXCEPTION) << "Exceed function call depth limit "
|
||||
<< MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH) << ".";
|
||||
<< MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)
|
||||
<< ", please call 'context.set_context(max_call_depth=value)' to adjust this value.";
|
||||
}
|
||||
std::vector<AnfNodePtr> nodes = FastShadowSort(func_node);
|
||||
for (auto it = nodes.crbegin(); it != nodes.crend(); it++) {
|
||||
|
|
|
@ -71,40 +71,29 @@ py::object MsCtxGetParameter(const std::shared_ptr<MsContext> &ctx, MsCtxParam p
|
|||
}
|
||||
} // namespace
|
||||
|
||||
// Note: exported python enum variables begining with '_' are for internal use
|
||||
REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) {
|
||||
(void)py::enum_<MsCtxParam>(*m, "ms_ctx_param", py::arithmetic())
|
||||
.value("enable_auto_mixed_precision", MsCtxParam::MS_CTX_ENABLE_AUTO_MIXED_PRECISION)
|
||||
.value("check_bprop", MsCtxParam::MS_CTX_CHECK_BPROP_FLAG)
|
||||
.value("enable_dump", MsCtxParam::MS_CTX_ENABLE_DUMP)
|
||||
.value("enable_dynamic_mem_pool", MsCtxParam::MS_CTX_ENABLE_DYNAMIC_MEM_POOL)
|
||||
.value("enable_gpu_summary", MsCtxParam::MS_CTX_ENABLE_GPU_SUMMARY)
|
||||
.value("enable_graph_kernel", MsCtxParam::MS_CTX_ENABLE_GRAPH_KERNEL)
|
||||
.value("enable_hccl", MsCtxParam::MS_CTX_ENABLE_HCCL)
|
||||
.value("enable_mem_reuse", MsCtxParam::MS_CTX_ENABLE_MEM_REUSE)
|
||||
.value("enable_pynative_hook", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_HOOK)
|
||||
.value("enable_pynative_infer", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_INFER)
|
||||
.value("enable_reduce_precision", MsCtxParam::MS_CTX_ENABLE_REDUCE_PRECISION)
|
||||
.value("enable_sparse", MsCtxParam::MS_CTX_ENABLE_SPARSE)
|
||||
.value("enable_task_sink", MsCtxParam::MS_CTX_ENABLE_TASK_SINK)
|
||||
.value("ir_fusion_flag", MsCtxParam::MS_CTX_IR_FUSION_FLAG)
|
||||
.value("is_multi_graph_sink", MsCtxParam::MS_CTX_IS_MULTI_GRAPH_SINK)
|
||||
.value("is_pynative_ge_init", MsCtxParam::MS_CTX_IS_PYNATIVE_GE_INIT)
|
||||
.value("precompile_only", MsCtxParam::MS_CTX_PRECOMPILE_ONLY)
|
||||
.value("enable_profiling", MsCtxParam::MS_CTX_ENABLE_PROFILING)
|
||||
.value("save_graphs", MsCtxParam::MS_CTX_SAVE_GRAPHS_FLAG)
|
||||
.value("max_device_memory", MsCtxParam::MS_CTX_MAX_DEVICE_MEMORY)
|
||||
.value("mode", MsCtxParam::MS_CTX_EXECUTION_MODE)
|
||||
.value("device_target", MsCtxParam::MS_CTX_DEVICE_TARGET)
|
||||
.value("graph_memory_max_size", MsCtxParam::MS_CTX_GRAPH_MEMORY_MAX_SIZE)
|
||||
.value("_graph_memory_max_size", MsCtxParam::MS_CTX_GRAPH_MEMORY_MAX_SIZE)
|
||||
.value("print_file_path", MsCtxParam::MS_CTX_PRINT_FILE_PATH)
|
||||
.value("profiling_options", MsCtxParam::MS_CTX_PROFILING_OPTIONS)
|
||||
.value("save_dump_path", MsCtxParam::MS_CTX_SAVE_DUMP_PATH)
|
||||
.value("save_graphs_path", MsCtxParam::MS_CTX_SAVE_GRAPHS_PATH)
|
||||
.value("variable_memory_max_size", MsCtxParam::MS_CTX_VARIABLE_MEMORY_MAX_SIZE)
|
||||
.value("device_id", MsCtxParam::MS_CTX_DEVICE_ID)
|
||||
.value("ge_ref", MsCtxParam::MS_CTX_GE_REF)
|
||||
.value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH)
|
||||
.value("tsd_ref", MsCtxParam::MS_CTX_TSD_REF);
|
||||
.value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH);
|
||||
|
||||
(void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(*m, "MSContext")
|
||||
.def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.")
|
||||
|
|
|
@ -221,6 +221,7 @@ class _Context:
|
|||
self.set_param(ms_ctx_param.profiling_options, option)
|
||||
|
||||
def set_variable_memory_max_size(self, variable_memory_max_size):
|
||||
"""set values of variable_memory_max_size and graph_memory_max_size"""
|
||||
if not _check_input_format(variable_memory_max_size):
|
||||
raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"")
|
||||
if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE:
|
||||
|
@ -229,7 +230,8 @@ class _Context:
|
|||
graph_memory_max_size = _DEVICE_APP_MEMORY_SIZE - int(variable_memory_max_size[:-2])
|
||||
graph_memory_max_size_ = str(graph_memory_max_size) + " * 1024 * 1024 * 1024"
|
||||
self.set_param(ms_ctx_param.variable_memory_max_size, variable_memory_max_size_)
|
||||
self.set_param(ms_ctx_param.graph_memory_max_size, graph_memory_max_size_)
|
||||
# pylint: disable=protected-access
|
||||
self.set_param(ms_ctx_param._graph_memory_max_size, graph_memory_max_size_)
|
||||
|
||||
def set_max_device_memory(self, max_device_memory):
|
||||
if not _check_input_format(max_device_memory):
|
||||
|
@ -427,6 +429,26 @@ def reset_auto_parallel_context():
|
|||
_reset_auto_parallel_context()
|
||||
|
||||
|
||||
def _check_target_specific_cfgs(device, arg_key):
|
||||
"""Checking whether a config is sutable for a specified device"""
|
||||
device_cfgs = {
|
||||
'enable_auto_mixed_precision': ['Ascend'],
|
||||
'enable_dump': ['Ascend'],
|
||||
'enable_profiling': ['Ascend'],
|
||||
'variable_memory_max_size': ['Ascend'],
|
||||
'max_device_memory': ['GPU']
|
||||
}
|
||||
# configs not in map device_cfgs are supposed to be suitable for all devices
|
||||
if not arg_key in device_cfgs:
|
||||
return True
|
||||
supported_devices = device_cfgs[arg_key]
|
||||
if device in supported_devices:
|
||||
return True
|
||||
logger.warning(f"Config '{arg_key}' only supports devices in {supported_devices}, current device is '{device}'"
|
||||
", ignore it.")
|
||||
return False
|
||||
|
||||
|
||||
@args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool,
|
||||
save_graphs_path=str, enable_dump=bool,
|
||||
save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
|
||||
|
@ -452,6 +474,26 @@ def set_context(**kwargs):
|
|||
The mode is not recommended to be changed after net was initilized because the implementations of some
|
||||
operations are different in graph mode and pynative mode. Default: PYNATIVE_MODE.
|
||||
|
||||
Some configurations are device specific, see the bellow table for details:
|
||||
|
||||
=========================== =========================== =================
|
||||
Common(CPU/GPU/Asecend) Ascend GPU
|
||||
=========================== =========================== =================
|
||||
check_bprop enable_auto_mixed_precision max_device_memory
|
||||
device_id enable_dump
|
||||
device_target enable_profiling
|
||||
enable_graph_kernel variable_memory_max_size
|
||||
enable_reduce_precision
|
||||
enable_sparse
|
||||
mode
|
||||
print_file_path
|
||||
profiling_options
|
||||
reserve_class_name_in_scope
|
||||
save_dump_path
|
||||
save_graphs
|
||||
save_graphs_path
|
||||
=========================== =========================== =================
|
||||
|
||||
Args:
|
||||
mode (int): Running in GRAPH_MODE(0) or PYNATIVE_MODE(1).
|
||||
device_target (str): The target device to run, support "Ascend", "GPU", "CPU". Default: "Ascend".
|
||||
|
@ -515,14 +557,21 @@ def set_context(**kwargs):
|
|||
>>> context.set_context(max_call_depth=80)
|
||||
"""
|
||||
ctx = _context()
|
||||
# set device target first
|
||||
if 'device_target' in kwargs:
|
||||
ctx.set_device_target(kwargs['device_target'])
|
||||
device = ctx.get_param(ms_ctx_param.device_target)
|
||||
for key, value in kwargs.items():
|
||||
if not _check_target_specific_cfgs(device, key):
|
||||
continue
|
||||
if hasattr(ctx, key):
|
||||
setattr(ctx, key, value)
|
||||
continue
|
||||
if key in ctx.setters:
|
||||
ctx.setters[key](ctx, value)
|
||||
continue
|
||||
if key in ms_ctx_param.__members__:
|
||||
# enum variables begining with '_' are for internal use
|
||||
if key in ms_ctx_param.__members__ and key[0] != '_':
|
||||
ctx.set_param(ms_ctx_param.__members__[key], value)
|
||||
continue
|
||||
raise ValueError("Set context keyword %s is not recognized!" % key)
|
||||
|
@ -542,9 +591,12 @@ def get_context(attr_key):
|
|||
ValueError: If input key is not an attribute in context.
|
||||
"""
|
||||
ctx = _context()
|
||||
device = ctx.get_param(ms_ctx_param.device_target)
|
||||
_ = _check_target_specific_cfgs(device, attr_key)
|
||||
if hasattr(ctx, attr_key):
|
||||
return getattr(ctx, attr_key)
|
||||
if attr_key in ms_ctx_param.__members__:
|
||||
# enum variables begining with '_' are for internal use
|
||||
if attr_key in ms_ctx_param.__members__ and attr_key[0] != '_':
|
||||
return ctx.get_param(ms_ctx_param.__members__[attr_key])
|
||||
raise ValueError("Get context keyword %s is not recognized!" % attr_key)
|
||||
|
||||
|
|
Loading…
Reference in New Issue