forked from mindspore-Ecosystem/mindspore
Simplify ms_context implementation
This commit is contained in:
parent
d5e02cf474
commit
0a858b3878
|
@ -50,55 +50,6 @@ using ParallelContext = mindspore::parallel::ParallelContext;
|
|||
using CostModelContext = mindspore::parallel::CostModelContext;
|
||||
using mindspore::MsCtxParam;
|
||||
|
||||
namespace mindspore {
|
||||
void MsCtxSetParameter(std::shared_ptr<MsContext> ctx, MsCtxParam param, const py::object &value) {
|
||||
MS_LOG(DEBUG) << "set param(" << param << ") with value '" << py::str(value) << "' of type '"
|
||||
<< py::str(value.get_type()) << "'.";
|
||||
if (param >= MS_CTX_TYPE_BOOL_BEGIN && param < MS_CTX_TYPE_BOOL_END && py::isinstance<py::bool_>(value)) {
|
||||
ctx->set_param<bool>(param, value.cast<bool>());
|
||||
return;
|
||||
}
|
||||
if (param >= MS_CTX_TYPE_INT_BEGIN && param < MS_CTX_TYPE_INT_END && py::isinstance<py::int_>(value)) {
|
||||
ctx->set_param<int>(param, value.cast<int>());
|
||||
return;
|
||||
}
|
||||
if (param >= MS_CTX_TYPE_UINT32_BEGIN && param < MS_CTX_TYPE_UINT32_END && py::isinstance<py::int_>(value)) {
|
||||
ctx->set_param<uint32_t>(param, value.cast<uint32_t>());
|
||||
return;
|
||||
}
|
||||
if (param >= MS_CTX_TYPE_FLOAT_BEGIN && param < MS_CTX_TYPE_FLOAT_END && py::isinstance<py::float_>(value)) {
|
||||
ctx->set_param<float>(param, value.cast<float>());
|
||||
return;
|
||||
}
|
||||
if (param >= MS_CTX_TYPE_STRING_BEGIN && param < MS_CTX_TYPE_STRING_END && py::isinstance<py::str>(value)) {
|
||||
ctx->set_param<std::string>(param, value.cast<std::string>());
|
||||
return;
|
||||
}
|
||||
|
||||
MS_LOG(EXCEPTION) << "Got illegal param " << param << " and value with type " << py::str(value.get_type());
|
||||
}
|
||||
|
||||
py::object MsCtxGetParameter(const std::shared_ptr<MsContext> &ctx, MsCtxParam param) {
|
||||
if (param >= MS_CTX_TYPE_BOOL_BEGIN && param < MS_CTX_TYPE_BOOL_END) {
|
||||
return py::bool_(ctx->get_param<bool>(param));
|
||||
}
|
||||
if (param >= MS_CTX_TYPE_INT_BEGIN && param < MS_CTX_TYPE_INT_END) {
|
||||
return py::int_(ctx->get_param<int>(param));
|
||||
}
|
||||
if (param >= MS_CTX_TYPE_UINT32_BEGIN && param < MS_CTX_TYPE_UINT32_END) {
|
||||
return py::int_(ctx->get_param<uint32_t>(param));
|
||||
}
|
||||
if (param >= MS_CTX_TYPE_FLOAT_BEGIN && param < MS_CTX_TYPE_FLOAT_END) {
|
||||
return py::float_(ctx->get_param<float>(param));
|
||||
}
|
||||
if (param >= MS_CTX_TYPE_STRING_BEGIN && param < MS_CTX_TYPE_STRING_END) {
|
||||
return py::str(ctx->get_param<std::string>(param));
|
||||
}
|
||||
|
||||
MS_LOG(EXCEPTION) << "Got illegal param " << param << ".";
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
||||
// Interface with python
|
||||
PYBIND11_MODULE(_c_expression, m) {
|
||||
m.doc() = "MindSpore c plugin";
|
||||
|
@ -151,49 +102,6 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
|
||||
(void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph.");
|
||||
|
||||
(void)m.def("ms_ctx_get_param", &mindspore::MsCtxGetParameter, "Get value of specified paramter.");
|
||||
(void)m.def("ms_ctx_set_param", &mindspore::MsCtxSetParameter, "Set value for specified paramter.");
|
||||
|
||||
(void)py::enum_<MsCtxParam>(*m, "ms_ctx_param", py::arithmetic())
|
||||
.value("auto_mixed_precision_flag", MsCtxParam::MS_CTX_AUTO_MIXED_PRECISION_FLAG)
|
||||
.value("check_bprop_flag", MsCtxParam::MS_CTX_CHECK_BPROP_FLAG)
|
||||
.value("enable_dump", MsCtxParam::MS_CTX_ENABLE_DUMP)
|
||||
.value("enable_dynamic_mem_pool", MsCtxParam::MS_CTX_ENABLE_DYNAMIC_MEM_POOL)
|
||||
.value("enable_gpu_summary", MsCtxParam::MS_CTX_ENABLE_GPU_SUMMARY)
|
||||
.value("enable_graph_kernel", MsCtxParam::MS_CTX_ENABLE_GRAPH_KERNEL)
|
||||
.value("enable_hccl", MsCtxParam::MS_CTX_ENABLE_HCCL)
|
||||
.value("enable_loop_sink", MsCtxParam::MS_CTX_ENABLE_LOOP_SINK)
|
||||
.value("enable_mem_reuse", MsCtxParam::MS_CTX_ENABLE_MEM_REUSE)
|
||||
.value("enable_pynative_hook", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_HOOK)
|
||||
.value("enable_pynative_infer", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_INFER)
|
||||
.value("enable_reduce_precision", MsCtxParam::MS_CTX_ENABLE_REDUCE_PRECISION)
|
||||
.value("enable_sparse", MsCtxParam::MS_CTX_ENABLE_SPARSE)
|
||||
.value("enable_task_sink", MsCtxParam::MS_CTX_ENABLE_TASK_SINK)
|
||||
.value("ir_fusion_flag", MsCtxParam::MS_CTX_IR_FUSION_FLAG)
|
||||
.value("is_multi_graph_sink", MsCtxParam::MS_CTX_IS_MULTI_GRAPH_SINK)
|
||||
.value("is_pynative_ge_init", MsCtxParam::MS_CTX_IS_PYNATIVE_GE_INIT)
|
||||
.value("precompile_only", MsCtxParam::MS_CTX_PRECOMPILE_ONLY)
|
||||
.value("enable_profiling", MsCtxParam::MS_CTX_ENABLE_PROFILING)
|
||||
.value("save_graphs_flag", MsCtxParam::MS_CTX_SAVE_GRAPHS_FLAG)
|
||||
.value("max_device_memory", MsCtxParam::MS_CTX_MAX_DEVICE_MEMORY)
|
||||
.value("execution_mode", MsCtxParam::MS_CTX_EXECUTION_MODE)
|
||||
.value("device_target", MsCtxParam::MS_CTX_DEVICE_TARGET)
|
||||
.value("graph_memory_max_size", MsCtxParam::MS_CTX_GRAPH_MEMORY_MAX_SIZE)
|
||||
.value("print_file_path", MsCtxParam::MS_CTX_PRINT_FILE_PATH)
|
||||
.value("profiling_options", MsCtxParam::MS_CTX_PROFILING_OPTIONS)
|
||||
.value("save_dump_path", MsCtxParam::MS_CTX_SAVE_DUMP_PATH)
|
||||
.value("save_graphs_path", MsCtxParam::MS_CTX_SAVE_GRAPHS_PATH)
|
||||
.value("variable_memory_max_size", MsCtxParam::MS_CTX_VARIABLE_MEMORY_MAX_SIZE)
|
||||
.value("device_id", MsCtxParam::MS_CTX_DEVICE_ID)
|
||||
.value("ge_ref", MsCtxParam::MS_CTX_GE_REF)
|
||||
.value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH)
|
||||
.value("tsd_ref", MsCtxParam::MS_CTX_TSD_REF);
|
||||
|
||||
(void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(m, "MSContext")
|
||||
.def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.")
|
||||
.def("get_backend_policy", &mindspore::MsContext::backend_policy, "Get backend policy.")
|
||||
.def("set_backend_policy", &mindspore::MsContext::set_backend_policy, "Set backend policy.");
|
||||
|
||||
(void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
|
||||
.def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")
|
||||
.def("get_enable_mpi", &mindspore::MpiConfig::enable_mpi, "Get whether enable mpi.")
|
||||
|
|
|
@ -0,0 +1,117 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "utils/ms_context.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "pybind_api/api_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace {
|
||||
void MsCtxSetParameter(std::shared_ptr<MsContext> ctx, MsCtxParam param, const py::object &value) {
|
||||
MS_LOG(DEBUG) << "set param(" << param << ") with value '" << py::str(value).cast<std::string>() << "' of type '"
|
||||
<< py::str(value.get_type()).cast<std::string>() << "'.";
|
||||
if (param >= MS_CTX_TYPE_BOOL_BEGIN && param < MS_CTX_TYPE_BOOL_END && py::isinstance<py::bool_>(value)) {
|
||||
ctx->set_param<bool>(param, value.cast<bool>());
|
||||
return;
|
||||
}
|
||||
if (param >= MS_CTX_TYPE_INT_BEGIN && param < MS_CTX_TYPE_INT_END && py::isinstance<py::int_>(value)) {
|
||||
ctx->set_param<int>(param, value.cast<int>());
|
||||
return;
|
||||
}
|
||||
if (param >= MS_CTX_TYPE_UINT32_BEGIN && param < MS_CTX_TYPE_UINT32_END && py::isinstance<py::int_>(value)) {
|
||||
ctx->set_param<uint32_t>(param, value.cast<uint32_t>());
|
||||
return;
|
||||
}
|
||||
if (param >= MS_CTX_TYPE_FLOAT_BEGIN && param < MS_CTX_TYPE_FLOAT_END && py::isinstance<py::float_>(value)) {
|
||||
ctx->set_param<float>(param, value.cast<float>());
|
||||
return;
|
||||
}
|
||||
if (param >= MS_CTX_TYPE_STRING_BEGIN && param < MS_CTX_TYPE_STRING_END && py::isinstance<py::str>(value)) {
|
||||
ctx->set_param<std::string>(param, value.cast<std::string>());
|
||||
return;
|
||||
}
|
||||
|
||||
MS_LOG(EXCEPTION) << "Got illegal param " << param << " and value with type "
|
||||
<< py::str(value.get_type()).cast<std::string>();
|
||||
}
|
||||
|
||||
py::object MsCtxGetParameter(const std::shared_ptr<MsContext> &ctx, MsCtxParam param) {
|
||||
if (param >= MS_CTX_TYPE_BOOL_BEGIN && param < MS_CTX_TYPE_BOOL_END) {
|
||||
return py::bool_(ctx->get_param<bool>(param));
|
||||
}
|
||||
if (param >= MS_CTX_TYPE_INT_BEGIN && param < MS_CTX_TYPE_INT_END) {
|
||||
return py::int_(ctx->get_param<int>(param));
|
||||
}
|
||||
if (param >= MS_CTX_TYPE_UINT32_BEGIN && param < MS_CTX_TYPE_UINT32_END) {
|
||||
return py::int_(ctx->get_param<uint32_t>(param));
|
||||
}
|
||||
if (param >= MS_CTX_TYPE_FLOAT_BEGIN && param < MS_CTX_TYPE_FLOAT_END) {
|
||||
return py::float_(ctx->get_param<float>(param));
|
||||
}
|
||||
if (param >= MS_CTX_TYPE_STRING_BEGIN && param < MS_CTX_TYPE_STRING_END) {
|
||||
return py::str(ctx->get_param<std::string>(param));
|
||||
}
|
||||
|
||||
MS_LOG(EXCEPTION) << "Got illegal param " << param << ".";
|
||||
}
|
||||
} // namespace
|
||||
|
||||
REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) {
|
||||
(void)py::enum_<MsCtxParam>(*m, "ms_ctx_param", py::arithmetic())
|
||||
.value("enable_auto_mixed_precision", MsCtxParam::MS_CTX_ENABLE_AUTO_MIXED_PRECISION)
|
||||
.value("check_bprop", MsCtxParam::MS_CTX_CHECK_BPROP_FLAG)
|
||||
.value("enable_dump", MsCtxParam::MS_CTX_ENABLE_DUMP)
|
||||
.value("enable_dynamic_mem_pool", MsCtxParam::MS_CTX_ENABLE_DYNAMIC_MEM_POOL)
|
||||
.value("enable_gpu_summary", MsCtxParam::MS_CTX_ENABLE_GPU_SUMMARY)
|
||||
.value("enable_graph_kernel", MsCtxParam::MS_CTX_ENABLE_GRAPH_KERNEL)
|
||||
.value("enable_hccl", MsCtxParam::MS_CTX_ENABLE_HCCL)
|
||||
.value("enable_loop_sink", MsCtxParam::MS_CTX_ENABLE_LOOP_SINK)
|
||||
.value("enable_mem_reuse", MsCtxParam::MS_CTX_ENABLE_MEM_REUSE)
|
||||
.value("enable_pynative_hook", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_HOOK)
|
||||
.value("enable_pynative_infer", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_INFER)
|
||||
.value("enable_reduce_precision", MsCtxParam::MS_CTX_ENABLE_REDUCE_PRECISION)
|
||||
.value("enable_sparse", MsCtxParam::MS_CTX_ENABLE_SPARSE)
|
||||
.value("enable_task_sink", MsCtxParam::MS_CTX_ENABLE_TASK_SINK)
|
||||
.value("ir_fusion_flag", MsCtxParam::MS_CTX_IR_FUSION_FLAG)
|
||||
.value("is_multi_graph_sink", MsCtxParam::MS_CTX_IS_MULTI_GRAPH_SINK)
|
||||
.value("is_pynative_ge_init", MsCtxParam::MS_CTX_IS_PYNATIVE_GE_INIT)
|
||||
.value("precompile_only", MsCtxParam::MS_CTX_PRECOMPILE_ONLY)
|
||||
.value("enable_profiling", MsCtxParam::MS_CTX_ENABLE_PROFILING)
|
||||
.value("save_graphs", MsCtxParam::MS_CTX_SAVE_GRAPHS_FLAG)
|
||||
.value("max_device_memory", MsCtxParam::MS_CTX_MAX_DEVICE_MEMORY)
|
||||
.value("mode", MsCtxParam::MS_CTX_EXECUTION_MODE)
|
||||
.value("device_target", MsCtxParam::MS_CTX_DEVICE_TARGET)
|
||||
.value("graph_memory_max_size", MsCtxParam::MS_CTX_GRAPH_MEMORY_MAX_SIZE)
|
||||
.value("print_file_path", MsCtxParam::MS_CTX_PRINT_FILE_PATH)
|
||||
.value("profiling_options", MsCtxParam::MS_CTX_PROFILING_OPTIONS)
|
||||
.value("save_dump_path", MsCtxParam::MS_CTX_SAVE_DUMP_PATH)
|
||||
.value("save_graphs_path", MsCtxParam::MS_CTX_SAVE_GRAPHS_PATH)
|
||||
.value("variable_memory_max_size", MsCtxParam::MS_CTX_VARIABLE_MEMORY_MAX_SIZE)
|
||||
.value("device_id", MsCtxParam::MS_CTX_DEVICE_ID)
|
||||
.value("ge_ref", MsCtxParam::MS_CTX_GE_REF)
|
||||
.value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH)
|
||||
.value("tsd_ref", MsCtxParam::MS_CTX_TSD_REF);
|
||||
|
||||
(void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(*m, "MSContext")
|
||||
.def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.")
|
||||
.def("get_param", &mindspore::MsCtxGetParameter, "Get value of specified paramter.")
|
||||
.def("set_param", &mindspore::MsCtxSetParameter, "Set value for specified paramter.")
|
||||
.def("get_backend_policy", &mindspore::MsContext::backend_policy, "Get backend policy.")
|
||||
.def("set_backend_policy", &mindspore::MsContext::set_backend_policy, "Set backend policy.");
|
||||
}));
|
||||
} // namespace mindspore
|
|
@ -225,7 +225,7 @@ void GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std
|
|||
}
|
||||
|
||||
// Enable auto mixed precision according to the context options
|
||||
if (ms_context_ptr->get_param<bool>(MS_CTX_AUTO_MIXED_PRECISION_FLAG)) {
|
||||
if (ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_AUTO_MIXED_PRECISION)) {
|
||||
(*ge_options)["ge.exec.precision_mode"] = "allow_mix_precision";
|
||||
} else {
|
||||
(*ge_options)["ge.exec.precision_mode"] = "allow_fp32_to_fp16";
|
||||
|
@ -337,7 +337,7 @@ bool FinalizeGe(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
|
|||
if (ge::GEFinalize() != ge::GRAPH_SUCCESS) {
|
||||
MS_LOG(WARNING) << "Finalize GE failed!";
|
||||
}
|
||||
ms_context_ptr->set_pynative_ge_init(false);
|
||||
ms_context_ptr->set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);
|
||||
} else {
|
||||
MS_LOG(INFO) << "Ge is used, no need to finalize, tsd reference = "
|
||||
<< ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) << ".";
|
||||
|
|
|
@ -22,7 +22,7 @@ import threading
|
|||
from collections import namedtuple
|
||||
from types import FunctionType
|
||||
from mindspore import log as logger
|
||||
from mindspore._c_expression import MSContext, ms_ctx_param, ms_ctx_get_param, ms_ctx_set_param
|
||||
from mindspore._c_expression import MSContext, ms_ctx_param
|
||||
from mindspore._checkparam import args_type_check
|
||||
from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
|
||||
_reset_auto_parallel_context
|
||||
|
@ -158,17 +158,12 @@ class _Context:
|
|||
return value
|
||||
|
||||
def get_param(self, param):
|
||||
return ms_ctx_get_param(self._context_handle, param)
|
||||
return self._context_handle.get_param(param)
|
||||
|
||||
def set_param(self, param, value):
|
||||
ms_ctx_set_param(self._context_handle, param, value)
|
||||
self._context_handle.set_param(param, value)
|
||||
|
||||
@property
|
||||
def mode(self):
|
||||
return self.get_param(ms_ctx_param.execution_mode)
|
||||
|
||||
@mode.setter
|
||||
def mode(self, mode):
|
||||
def set_mode(self, mode):
|
||||
"""
|
||||
Switch between Graph mode and PyNative mode.
|
||||
|
||||
|
@ -185,43 +180,17 @@ class _Context:
|
|||
self._context_switches.push(False, None)
|
||||
else:
|
||||
raise ValueError(f'The execution mode {mode} is invalid!')
|
||||
self.set_param(ms_ctx_param.execution_mode, mode)
|
||||
self.set_param(ms_ctx_param.mode, mode)
|
||||
|
||||
def set_backend_policy(self, policy):
|
||||
success = self._context_handle.set_backend_policy(policy)
|
||||
if not success:
|
||||
raise RuntimeError("Backend policy must be one of ge, vm, ms.")
|
||||
|
||||
@property
|
||||
def precompile_only(self):
|
||||
return self.get_param(ms_ctx_param.precompile_only)
|
||||
|
||||
@precompile_only.setter
|
||||
def precompile_only(self, precompile_only):
|
||||
self.set_param(ms_ctx_param.precompile_only, precompile_only)
|
||||
|
||||
@property
|
||||
def save_graphs(self):
|
||||
return self.get_param(ms_ctx_param.save_graphs_flag)
|
||||
|
||||
@save_graphs.setter
|
||||
def save_graphs(self, save_graphs_flag):
|
||||
self.set_param(ms_ctx_param.save_graphs_flag, save_graphs_flag)
|
||||
|
||||
@property
|
||||
def save_graphs_path(self):
|
||||
return self.get_param(ms_ctx_param.save_graphs_path)
|
||||
|
||||
@save_graphs_path.setter
|
||||
def save_graphs_path(self, save_graphs_path):
|
||||
def set_save_graphs_path(self, save_graphs_path):
|
||||
self.set_param(ms_ctx_param.save_graphs_path, _make_directory(save_graphs_path))
|
||||
|
||||
@property
|
||||
def device_target(self):
|
||||
return self.get_param(ms_ctx_param.device_target)
|
||||
|
||||
@device_target.setter
|
||||
def device_target(self, target):
|
||||
def set_device_target(self, target):
|
||||
valid_targets = ["CPU", "GPU", "Ascend", "Davinci"]
|
||||
if not target in valid_targets:
|
||||
raise ValueError(f"Target device name {target} is invalid! It must be one of {valid_targets}")
|
||||
|
@ -231,72 +200,17 @@ class _Context:
|
|||
if self.enable_debug_runtime and target == "CPU":
|
||||
self.set_backend_policy("vm")
|
||||
|
||||
@property
|
||||
def device_id(self):
|
||||
return self.get_param(ms_ctx_param.device_id)
|
||||
|
||||
@device_id.setter
|
||||
def device_id(self, device_id):
|
||||
def set_device_id(self, device_id):
|
||||
if device_id < 0 or device_id > 4095:
|
||||
raise ValueError(f"Device id must be in [0, 4095], but got {device_id}")
|
||||
self.set_param(ms_ctx_param.device_id, device_id)
|
||||
|
||||
@property
|
||||
def max_call_depth(self):
|
||||
return self.get_param(ms_ctx_param.max_call_depth)
|
||||
|
||||
@max_call_depth.setter
|
||||
def max_call_depth(self, max_call_depth):
|
||||
def set_max_call_depth(self, max_call_depth):
|
||||
if max_call_depth <= 0:
|
||||
raise ValueError(f"Max call depth must be greater than 0, but got {max_call_depth}")
|
||||
self.set_param(ms_ctx_param.max_call_depth, max_call_depth)
|
||||
|
||||
@property
|
||||
def enable_auto_mixed_precision(self):
|
||||
return self.get_param(ms_ctx_param.auto_mixed_precision_flag)
|
||||
|
||||
@enable_auto_mixed_precision.setter
|
||||
def enable_auto_mixed_precision(self, enable_auto_mixed_precision):
|
||||
self.set_param(ms_ctx_param.auto_mixed_precision_flag, enable_auto_mixed_precision)
|
||||
|
||||
@property
|
||||
def enable_reduce_precision(self):
|
||||
return self.get_param(ms_ctx_param.enable_reduce_precision_flag)
|
||||
|
||||
@enable_reduce_precision.setter
|
||||
def enable_reduce_precision(self, enable_reduce_precision):
|
||||
self.set_param(ms_ctx_param.enable_reduce_precision_flag, enable_reduce_precision)
|
||||
|
||||
@property
|
||||
def enable_dump(self):
|
||||
return self.get_param(ms_ctx_param.enable_dump)
|
||||
|
||||
@enable_dump.setter
|
||||
def enable_dump(self, enable_dump):
|
||||
self.set_param(ms_ctx_param.enable_dump, enable_dump)
|
||||
|
||||
@property
|
||||
def save_dump_path(self):
|
||||
return self.get_param(ms_ctx_param.save_dump_path)
|
||||
|
||||
@save_dump_path.setter
|
||||
def save_dump_path(self, save_dump_path):
|
||||
self.set_param(ms_ctx_param.save_dump_path, save_dump_path)
|
||||
|
||||
@property
|
||||
def enable_profiling(self):
|
||||
return self.get_param(ms_ctx_param.enable_profiling)
|
||||
|
||||
@enable_profiling.setter
|
||||
def enable_profiling(self, flag):
|
||||
self.set_param(ms_ctx_param.enable_profiling, flag)
|
||||
|
||||
@property
|
||||
def profiling_options(self):
|
||||
return self.get_param(ms_ctx_param.profiling_options)
|
||||
|
||||
@profiling_options.setter
|
||||
def profiling_options(self, option):
|
||||
def set_profiling_options(self, option):
|
||||
options = ["training_trace", "task_trace",
|
||||
"task_trace:training_trace", "training_trace:task_trace", "op_trace"]
|
||||
if option not in options:
|
||||
|
@ -304,30 +218,7 @@ class _Context:
|
|||
"'task_trace:training_trace' 'training_trace:task_trace' or 'op_trace'.")
|
||||
self.set_param(ms_ctx_param.profiling_options, option)
|
||||
|
||||
@property
|
||||
def enable_graph_kernel(self):
|
||||
return self.get_param(ms_ctx_param.enable_graph_kernel)
|
||||
|
||||
@enable_graph_kernel.setter
|
||||
def enable_graph_kernel(self, graph_kernel_switch_):
|
||||
self.set_param(ms_ctx_param.enable_graph_kernel, graph_kernel_switch_)
|
||||
|
||||
@property
|
||||
def reserve_class_name_in_scope(self):
|
||||
"""Gets whether to save the network class name in the scope."""
|
||||
return self._thread_local_info.reserve_class_name_in_scope
|
||||
|
||||
@reserve_class_name_in_scope.setter
|
||||
def reserve_class_name_in_scope(self, reserve_class_name_in_scope):
|
||||
"""Sets whether to save the network class name in the scope."""
|
||||
self._thread_local_info.reserve_class_name_in_scope = reserve_class_name_in_scope
|
||||
|
||||
@property
|
||||
def variable_memory_max_size(self):
|
||||
return None
|
||||
|
||||
@variable_memory_max_size.setter
|
||||
def variable_memory_max_size(self, variable_memory_max_size):
|
||||
def set_variable_memory_max_size(self, variable_memory_max_size):
|
||||
if not check_input_format(variable_memory_max_size):
|
||||
raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"")
|
||||
if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE:
|
||||
|
@ -338,33 +229,7 @@ class _Context:
|
|||
self.set_param(ms_ctx_param.variable_memory_max_size, variable_memory_max_size_)
|
||||
self.set_param(ms_ctx_param.graph_memory_max_size, graph_memory_max_size_)
|
||||
|
||||
@property
|
||||
def enable_ge(self):
|
||||
return self._context_handle.get_backend_policy() == 'ge'
|
||||
|
||||
@property
|
||||
def enable_debug_runtime(self):
|
||||
return self._thread_local_info.debug_runtime
|
||||
|
||||
@enable_debug_runtime.setter
|
||||
def enable_debug_runtime(self, enable):
|
||||
thread_info = self._thread_local_info
|
||||
thread_info.debug_runtime = enable
|
||||
|
||||
@property
|
||||
def check_bprop(self):
|
||||
return self.get_param(ms_ctx_param.check_bprop_flag)
|
||||
|
||||
@check_bprop.setter
|
||||
def check_bprop(self, check_bprop_flag):
|
||||
self.set_param(ms_ctx_param.check_bprop_flag, check_bprop_flag)
|
||||
|
||||
@property
|
||||
def max_device_memory(self):
|
||||
return self.get_param(ms_ctx_param.max_device_memory)
|
||||
|
||||
@max_device_memory.setter
|
||||
def max_device_memory(self, max_device_memory):
|
||||
def set_max_device_memory(self, max_device_memory):
|
||||
if not check_input_format(max_device_memory):
|
||||
raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"")
|
||||
max_device_memory_value = float(max_device_memory[:-2])
|
||||
|
@ -372,12 +237,7 @@ class _Context:
|
|||
raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"")
|
||||
self.set_param(ms_ctx_param.max_device_memory, max_device_memory_value)
|
||||
|
||||
@property
|
||||
def print_file_path(self):
|
||||
return None
|
||||
|
||||
@print_file_path.setter
|
||||
def print_file_path(self, file_path):
|
||||
def set_print_file_path(self, file_path):
|
||||
"""Add timestamp suffix to file name. Sets print file path."""
|
||||
print_file_path = os.path.realpath(file_path)
|
||||
if os.path.isdir(print_file_path):
|
||||
|
@ -392,13 +252,42 @@ class _Context:
|
|||
full_file_name = print_file_path
|
||||
self.set_param(ms_ctx_param.print_file_path, full_file_name)
|
||||
|
||||
@property
|
||||
def enable_sparse(self):
|
||||
return self.get_param(ms_ctx_param.enable_sparse)
|
||||
setters = {
|
||||
'mode': set_mode,
|
||||
'backend_policy': set_backend_policy,
|
||||
'save_graphs_path': set_save_graphs_path,
|
||||
'device_target': set_device_target,
|
||||
'device_id': set_device_id,
|
||||
'max_call_depth': set_max_call_depth,
|
||||
'profiling_options': set_profiling_options,
|
||||
'variable_memory_max_size': set_variable_memory_max_size,
|
||||
'max_device_memory': set_max_device_memory,
|
||||
'print_file_path': set_print_file_path
|
||||
}
|
||||
|
||||
@property
|
||||
def reserve_class_name_in_scope(self):
|
||||
"""Gets whether to save the network class name in the scope."""
|
||||
return self._thread_local_info.reserve_class_name_in_scope
|
||||
|
||||
@reserve_class_name_in_scope.setter
|
||||
def reserve_class_name_in_scope(self, reserve_class_name_in_scope):
|
||||
"""Sets whether to save the network class name in the scope."""
|
||||
self._thread_local_info.reserve_class_name_in_scope = reserve_class_name_in_scope
|
||||
|
||||
@property
|
||||
def enable_ge(self):
|
||||
return self._context_handle.get_backend_policy() == 'ge'
|
||||
|
||||
@property
|
||||
def enable_debug_runtime(self):
|
||||
return self._thread_local_info.debug_runtime
|
||||
|
||||
@enable_debug_runtime.setter
|
||||
def enable_debug_runtime(self, enable):
|
||||
thread_info = self._thread_local_info
|
||||
thread_info.debug_runtime = enable
|
||||
|
||||
@enable_sparse.setter
|
||||
def enable_sparse(self, enable_sparse):
|
||||
self.set_param(ms_ctx_param.enable_sparse, enable_sparse)
|
||||
|
||||
def check_input_format(x):
|
||||
import re
|
||||
|
@ -621,10 +510,18 @@ def set_context(**kwargs):
|
|||
>>> context.set_context(print_file_path="print.pb")
|
||||
>>> context.set_context(max_call_depth=80)
|
||||
"""
|
||||
ctx = _context()
|
||||
for key, value in kwargs.items():
|
||||
if not hasattr(_context(), key):
|
||||
raise ValueError("Set context keyword %s is not recognized!" % key)
|
||||
setattr(_context(), key, value)
|
||||
if hasattr(ctx, key):
|
||||
setattr(ctx, key, value)
|
||||
continue
|
||||
if key in ctx.setters:
|
||||
ctx.setters[key](ctx, value)
|
||||
continue
|
||||
if key in ms_ctx_param.__members__:
|
||||
ctx.set_param(ms_ctx_param.__members__[key], value)
|
||||
continue
|
||||
raise ValueError("Set context keyword %s is not recognized!" % key)
|
||||
|
||||
|
||||
def get_context(attr_key):
|
||||
|
@ -640,10 +537,13 @@ def get_context(attr_key):
|
|||
Raises:
|
||||
ValueError: If input key is not an attribute in context.
|
||||
"""
|
||||
if not hasattr(_context(), attr_key):
|
||||
raise ValueError(
|
||||
"Get context keyword %s is not recognized!" % attr_key)
|
||||
return getattr(_context(), attr_key)
|
||||
ctx = _context()
|
||||
if hasattr(ctx, attr_key):
|
||||
return getattr(ctx, attr_key)
|
||||
if attr_key in ms_ctx_param.__members__:
|
||||
return ctx.get_param(ms_ctx_param.__members__[attr_key])
|
||||
raise ValueError("Get context keyword %s is not recognized!" % attr_key)
|
||||
|
||||
|
||||
class ParallelMode:
|
||||
"""
|
||||
|
|
|
@ -60,7 +60,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
|
|||
#endif
|
||||
set_param<bool>(MS_CTX_ENABLE_GPU_SUMMARY, true);
|
||||
set_param<bool>(MS_CTX_PRECOMPILE_ONLY, false);
|
||||
set_param<bool>(MS_CTX_AUTO_MIXED_PRECISION_FLAG, false);
|
||||
set_param<bool>(MS_CTX_ENABLE_AUTO_MIXED_PRECISION, false);
|
||||
set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
|
||||
set_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK, false);
|
||||
set_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL, true);
|
||||
|
|
|
@ -53,7 +53,7 @@ const float kDefaultMaxDeviceMemory = 1024;
|
|||
enum MsCtxParam : unsigned {
|
||||
// paramater of type bool
|
||||
MS_CTX_TYPE_BOOL_BEGIN,
|
||||
MS_CTX_AUTO_MIXED_PRECISION_FLAG = MS_CTX_TYPE_BOOL_BEGIN,
|
||||
MS_CTX_ENABLE_AUTO_MIXED_PRECISION = MS_CTX_TYPE_BOOL_BEGIN,
|
||||
MS_CTX_CHECK_BPROP_FLAG,
|
||||
MS_CTX_ENABLE_DUMP,
|
||||
MS_CTX_ENABLE_DYNAMIC_MEM_POOL,
|
||||
|
@ -132,22 +132,22 @@ class MsContext {
|
|||
|
||||
template <typename T>
|
||||
void set_param(MsCtxParam param, const T &value) {
|
||||
MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
|
||||
MS_LOG(EXCEPTION) << "Need to implement " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const T &get_param(MsCtxParam param) const {
|
||||
MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
|
||||
MS_LOG(EXCEPTION) << "Need to implement " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void increase_param(MsCtxParam param) {
|
||||
MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
|
||||
MS_LOG(EXCEPTION) << "Need to implement " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void decrease_param(MsCtxParam param) {
|
||||
MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
|
||||
MS_LOG(EXCEPTION) << "Need to implement " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
Loading…
Reference in New Issue