!37579 add the context disable_format_transform

Merge pull request !37579 from limingqi107/new_actor_runtime
i-robot 2022-07-08 05:39:45 +00:00 committed by Gitee
commit 5f72144520
6 changed files with 28 additions and 8 deletions

@@ -64,6 +64,8 @@ mindspore.set_context
|                         | runtime_num_threads          | CPU/GPU/Ascend             |
|                         +------------------------------+----------------------------+
|                         | compile_cache_path           | CPU/GPU/Ascend             |
|                         +------------------------------+----------------------------+
|                         | disable_format_transform     | GPU                        |
+-------------------------+------------------------------+----------------------------+
**Parameters:**
@@ -131,6 +133,7 @@ mindspore.set_context
- **enable_compile_cache** (bool) - Whether to load or save the graph compiled by the frontend. When `enable_compile_cache` is set to True, a hardware-independent compilation cache is generated and exported to a MINDIR file during the first execution. When the network is executed again, if `enable_compile_cache` is still True and the network script has not been changed, the compilation cache is loaded. Note that only limited automatic detection of Python script changes is supported at present, which means there is a correctness risk. Default: False. This is an experimental feature that may be changed or removed.
- **compile_cache_path** (str) - Path that saves the compilation cache of the frontend graph. Default: ".". If the directory does not exist, the system creates it automatically. The cache is saved to the directory `compile_cache_path/rank_${rank_id}/`, where `rank_id` is the ID of the current device in the cluster.
- **runtime_num_threads** (int) - Controls the number of threads in the runtime thread pool. Default: 30.
- **disable_format_transform** (bool) - Whether to disable the automatic NCHW-to-NHWC format transform. When the network performance with fp16 is worse than with fp32, `disable_format_transform` can be set to True to try to improve training performance. Default: False.
**Raises:**

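For context, a minimal usage sketch of the new flag (a hypothetical run script assuming a GPU build of MindSpore; not part of this change):

import mindspore as ms

# Disable the automatic NCHW -> NHWC format transform on GPU.
# Worth trying when fp16 training runs slower than fp32.
ms.set_context(device_target="GPU", disable_format_transform=True)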
@@ -296,11 +296,6 @@ void TransformFormatPosition(std::vector<size_t> *format_position, size_t positi
}
bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeId> &inputs_type) {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
return false;
}
if (!FormatTransformChecker::GetInstance().format_transform()) {
return false;
}
@@ -467,6 +462,18 @@ std::pair<std::string, ExceptionType> PrintUnsupportedTypeWarning(const CNodePtr
void FormatTransformChecker::CheckSupportFormatTransform(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<bool>(MS_CTX_DISABLE_FORMAT_TRANSFORM)) {
MS_LOG(INFO) << "Disable the automatic format transform function.";
format_transform_ = false;
return;
}
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
format_transform_ = false;
return;
}
// TensorCore can be used only in Volta or newer devices.
const int major_sm = GET_MAJOR_SM;
if (major_sm < RECOMMEND_SM) {
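The decision order above matters: the new context flag is checked first, then PyNative mode (a check moved here from IsNeedProcessFormatInfo), then the compute capability. A rough Python paraphrase of these early exits, not the literal implementation (the RECOMMEND_SM value of 7, i.e. Volta, is an assumption based on the TensorCore comment):

def format_transform_enabled(disable_flag, pynative_mode, major_sm, recommend_sm=7):
    """Mirror the early exits of FormatTransformChecker::CheckSupportFormatTransform."""
    if disable_flag:        # MS_CTX_DISABLE_FORMAT_TRANSFORM set via set_context
        return False
    if pynative_mode:       # PyNative mode never applies the transform
        return False
    if major_sm < recommend_sm:  # TensorCore needs Volta (sm_70) or newer
        return False
    return True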

@@ -100,7 +100,8 @@ REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) {
.value("env_config_path", MsCtxParam::MS_CTX_ENV_CONFIG_PATH)
.value("graph_kernel_flags", MsCtxParam::MS_CTX_GRAPH_KERNEL_FLAGS)
.value("grad_for_scalar", MsCtxParam::MS_CTX_GRAD_FOR_SCALAR)
.value("pynative_synchronize", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE);
.value("pynative_synchronize", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE)
.value("disable_format_transform", MsCtxParam::MS_CTX_DISABLE_FORMAT_TRANSFORM);
(void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(*m, "MSContext")
.def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.")
.def("get_param", &mindspore::MsCtxGetParameter, "Get value of specified parameter.")

@@ -100,6 +100,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
set_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE, true);
set_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER, false);
set_param<bool>(MS_CTX_ENABLE_RECOVERY, false);
set_param<bool>(MS_CTX_DISABLE_FORMAT_TRANSFORM, false);
uint32_t kDefaultRuntimeNumThreads = 30;
uint32_t cpu_core_num = std::thread::hardware_concurrency() - 1;
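The default registered here should be observable from Python before any explicit set_context call; a small check (assumption: get_context reads the same parameter table):

import mindspore as ms

# disable_format_transform defaults to False, i.e. the automatic
# format transform stays enabled unless explicitly turned off.
assert ms.get_context("disable_format_transform") is False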

@@ -92,6 +92,7 @@ enum MsCtxParam : unsigned {
MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE,
MS_CTX_ENABLE_MEM_SCHEDULER,
MS_CTX_ENABLE_RECOVERY,
MS_CTX_DISABLE_FORMAT_TRANSFORM,
MS_CTX_TYPE_BOOL_END,
// parameter of type int

@@ -627,7 +627,8 @@ def _check_target_specific_cfgs(device, arg_key):
'variable_memory_max_size': ['Ascend'],
'auto_tune_mode': ['Ascend'],
'max_device_memory': ['Ascend', 'GPU'],
'mempool_block_size': ['GPU', 'Ascend']
'mempool_block_size': ['GPU', 'Ascend'],
'disable_format_transform': ['GPU']
}
# configs not in map device_cfgs are supposed to be suitable for all devices
if arg_key not in device_cfgs:
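Because the map above lists disable_format_transform under GPU only, setting it with another device target fails this check. An illustrative snippet (the exact warning behavior is an assumption):

import mindspore as ms

ms.set_context(device_target="CPU")
# Expected to log a warning and leave the flag untouched, since
# _check_target_specific_cfgs returns False for non-GPU targets.
ms.set_context(disable_format_transform=True)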
@@ -649,7 +650,7 @@ def _check_target_specific_cfgs(device, arg_key):
enable_graph_kernel=bool, reserve_class_name_in_scope=bool, check_bprop=bool,
max_device_memory=str, print_file_path=str, max_call_depth=int, env_config_path=str,
graph_kernel_flags=str, save_compile_cache=bool, runtime_num_threads=int, load_compile_cache=bool,
grad_for_scalar=bool, pynative_synchronize=bool, mempool_block_size=str)
grad_for_scalar=bool, pynative_synchronize=bool, mempool_block_size=str, disable_format_transform=bool)
def set_context(**kwargs):
"""
Set context for running environment.
@@ -716,6 +717,8 @@ def set_context(**kwargs):
|                         | runtime_num_threads          | CPU/GPU/Ascend             |
|                         +------------------------------+----------------------------+
|                         | compile_cache_path           | CPU/GPU/Ascend             |
|                         +------------------------------+----------------------------+
|                         | disable_format_transform     | GPU                        |
+-------------------------+------------------------------+----------------------------+
Args:
@@ -848,6 +851,9 @@ def set_context(**kwargs):
runtime_num_threads(int): The thread pool size used by the runtime for cpu kernels and actors,
which must be bigger than 0. Default value is 30. If you run many processes at
the same time, you should set the value smaller to avoid thread contention.
disable_format_transform (bool): Whether to disable the automatic format transform function from NCHW to NHWC.
When the network training performance of fp16 is worse than fp32,
`disable_format_transform` can be set to True to try to improve training performance. Default: False.
Raises:
ValueError: If input key is not an attribute in context.
@@ -874,6 +880,7 @@ def set_context(**kwargs):
>>> ms.set_context(enable_compile_cache=True, compile_cache_path="./cache.ms")
>>> ms.set_context(pynative_synchronize=True)
>>> ms.set_context(runtime_num_threads=10)
>>> ms.set_context(disable_format_transform=True)
"""
ctx = _context()
# set device target first