!37579 add the context disable_format_transform
Merge pull request !37579 from limingqi107/new_actor_runtime
commit 5f72144520
@@ -64,6 +64,8 @@ mindspore.set_context
     |                         | runtime_num_threads         | CPU/GPU/Ascend             |
     |                         +------------------------------+----------------------------+
     |                         | compile_cache_path          | CPU/GPU/Ascend             |
+    |                         +------------------------------+----------------------------+
+    |                         | disable_format_transform    | GPU                        |
     +-------------------------+------------------------------+----------------------------+

     **Parameters:**
@@ -131,6 +133,7 @@ mindspore.set_context
     - **enable_compile_cache** (bool) - Whether to load or save the graphs compiled by the frontend. When `enable_compile_cache` is set to True, a hardware-independent compilation cache is generated and exported to a MINDIR file during the first execution. When the network is executed again, if `enable_compile_cache` is still True and the network script has not been changed, the compilation cache is loaded. Note that only limited automatic detection of Python script changes is currently supported, which means there is a correctness risk. Default: False. This is an experimental feature that may be changed or removed.
     - **compile_cache_path** (str) - The path to save the compilation cache of the frontend graph. Default: ".". If the directory does not exist, the system creates it automatically. The cache is saved to the directory `compile_cache_path/rank_${rank_id}/`, where `rank_id` is the ID of the current device in the cluster.
     - **runtime_num_threads** (int) - The number of threads in the runtime thread pool. Default: 30.
+    - **disable_format_transform** (bool) - Whether to disable the automatic format transform from NCHW to NHWC. When the network performance of fp16 is worse than that of fp32, `disable_format_transform` can be set to True to try to improve training performance. Default: False.

     **Raises:**

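A minimal usage sketch of the new flag documented above, mirroring the doctest example this change adds to the English docstring further down; `device_target="GPU"` is written out because the flag is listed as GPU-only:

    import mindspore as ms

    # If an fp16 network trains slower than its fp32 counterpart on GPU, keep the
    # original NCHW format instead of letting the backend transform it to NHWC.
    ms.set_context(device_target="GPU", disable_format_transform=True)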
@@ -296,11 +296,6 @@ void TransformFormatPosition(std::vector<size_t> *format_position, size_t positi
 }

 bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeId> &inputs_type) {
-  auto ms_context = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(ms_context);
-  if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
-    return false;
-  }
   if (!FormatTransformChecker::GetInstance().format_transform()) {
     return false;
   }
@@ -467,6 +462,18 @@ std::pair<std::string, ExceptionType> PrintUnsupportedTypeWarning(const CNodePtr

 void FormatTransformChecker::CheckSupportFormatTransform(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
+  auto ms_context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(ms_context);
+  if (ms_context->get_param<bool>(MS_CTX_DISABLE_FORMAT_TRANSFORM)) {
+    MS_LOG(INFO) << "Disable the automatic format transform function.";
+    format_transform_ = false;
+    return;
+  }
+  if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
+    format_transform_ = false;
+    return;
+  }

   // TensorCore can be used only in Volta or newer devices.
   const int marjor_sm = GET_MAJOR_SM;
   if (marjor_sm < RECOMMEND_SM) {
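The checker above gates the whole transform. For readability, here is a compact restatement of the branches visible in this hunk as a Python sketch; the real logic stays in C++, RECOMMEND_SM is assumed to correspond to Volta's SM 7.0, and the hunk is truncated, so further conditions may follow in the original function:

    def format_transform_enabled(disable_format_transform, is_pynative_mode, major_sm, recommend_sm=7):
        """Sketch of the visible branches of FormatTransformChecker::CheckSupportFormatTransform."""
        if disable_format_transform:   # user asked to keep NCHW via set_context
            return False
        if is_pynative_mode:           # PyNative execution skips the transform
            return False
        if major_sm < recommend_sm:    # TensorCore needs Volta (SM 7.0) or newer
            return False
        return True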
@@ -100,7 +100,8 @@ REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) {
     .value("env_config_path", MsCtxParam::MS_CTX_ENV_CONFIG_PATH)
     .value("graph_kernel_flags", MsCtxParam::MS_CTX_GRAPH_KERNEL_FLAGS)
     .value("grad_for_scalar", MsCtxParam::MS_CTX_GRAD_FOR_SCALAR)
-    .value("pynative_synchronize", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE);
+    .value("pynative_synchronize", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE)
+    .value("disable_format_transform", MsCtxParam::MS_CTX_DISABLE_FORMAT_TRANSFORM);
   (void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(*m, "MSContext")
     .def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.")
     .def("get_param", &mindspore::MsCtxGetParameter, "Get value of specified parameter.")
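With the binding above, the new key becomes a named MsCtxParam on the Python side. A hedged round-trip check; it assumes `mindspore.get_context` exposes the key in the same way as the other boolean parameters, which this diff does not show:

    import mindspore as ms

    ms.set_context(disable_format_transform=True)
    # Read the value back through the context getter (assumed to cover the new key).
    print(ms.get_context("disable_format_transform"))  # expected: True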
@@ -100,6 +100,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
   set_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE, true);
   set_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER, false);
   set_param<bool>(MS_CTX_ENABLE_RECOVERY, false);
+  set_param<bool>(MS_CTX_DISABLE_FORMAT_TRANSFORM, false);

   uint32_t kDefaultRuntimeNumThreads = 30;
   uint32_t cpu_core_num = std::thread::hardware_concurrency() - 1;
@@ -92,6 +92,7 @@ enum MsCtxParam : unsigned {
   MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE,
   MS_CTX_ENABLE_MEM_SCHEDULER,
   MS_CTX_ENABLE_RECOVERY,
+  MS_CTX_DISABLE_FORMAT_TRANSFORM,
   MS_CTX_TYPE_BOOL_END,

   // parameter of type int
@@ -627,7 +627,8 @@ def _check_target_specific_cfgs(device, arg_key):
         'variable_memory_max_size': ['Ascend'],
         'auto_tune_mode': ['Ascend'],
         'max_device_memory': ['Ascend', 'GPU'],
-        'mempool_block_size': ['GPU', 'Ascend']
+        'mempool_block_size': ['GPU', 'Ascend'],
+        'disable_format_transform': ['GPU']
     }
     # configs not in map device_cfgs are supposed to be suitable for all devices
     if not arg_key in device_cfgs:
@@ -649,7 +650,7 @@ def _check_target_specific_cfgs(device, arg_key):
                  enable_graph_kernel=bool, reserve_class_name_in_scope=bool, check_bprop=bool,
                  max_device_memory=str, print_file_path=str, max_call_depth=int, env_config_path=str,
                  graph_kernel_flags=str, save_compile_cache=bool, runtime_num_threads=int, load_compile_cache=bool,
-                 grad_for_scalar=bool, pynative_synchronize=bool, mempool_block_size=str)
+                 grad_for_scalar=bool, pynative_synchronize=bool, mempool_block_size=str, disable_format_transform=bool)
 def set_context(**kwargs):
     """
     Set context for running environment.
@@ -716,6 +717,8 @@ def set_context(**kwargs):
     |                         | runtime_num_threads         | CPU/GPU/Ascend             |
     |                         +------------------------------+----------------------------+
     |                         | compile_cache_path          | CPU/GPU/Ascend             |
+    |                         +------------------------------+----------------------------+
+    |                         | disable_format_transform    | GPU                        |
     +-------------------------+------------------------------+----------------------------+

     Args:
@@ -848,6 +851,9 @@ def set_context(**kwargs):
         runtime_num_threads(int): The thread pool number of cpu kernel and actor used in runtime,
             which must bigger than 0. Default value is 30, if you run many processes at
             the same time, you should set the value smaller to avoid thread contention.
+        disable_format_transform (bool): Whether to disable the automatic format transform function from NCHW to NHWC.
+            When the network training performance of fp16 is worse than fp32,
+            `disable_format_transform` can be set to True to try to improve training performance. Default: False.

     Raises:
         ValueError: If input key is not an attribute in context.
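The runtime_num_threads guidance above is easy to act on; a small hedged sketch for a host that runs several MindSpore processes at once (the value 5 is arbitrary):

    import mindspore as ms

    # When several processes share one machine, shrink each runtime thread pool
    # so the processes do not contend for CPU threads.
    ms.set_context(runtime_num_threads=5)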
@@ -874,6 +880,7 @@ def set_context(**kwargs):
         >>> ms.set_context(enable_compile_cache=True, compile_cache_path="./cache.ms")
         >>> ms.set_context(pynative_synchronize=True)
         >>> ms.set_context(runtime_num_threads=10)
+        >>> ms.set_context(disable_format_transform=True)
     """
     ctx = _context()
     # set device target first