forked from mindspore-Ecosystem/mindspore
!48857 support precision_mode and jit_compile for mindspore.set_context
Merge pull request !48857 from zhangyinxia/master
commit 40b3f79f8b
@@ -78,6 +78,8 @@ mindspore.set_context
|                         | memory_optimize_level        | CPU/GPU/Ascend             |
|                         +------------------------------+----------------------------+
|                         | memory_offload               | GPU/Ascend                 |
|                         +------------------------------+----------------------------+
|                         | ascend_config                | Ascend                     |
+-------------------------+------------------------------+----------------------------+

Parameters:
@@ -167,6 +169,22 @@ mindspore.set_context
- ON: Enable the memory offload function. On the Ascend hardware platform, this parameter takes no effect when the environment variable "GRAPH_OP_RUN=1" is not set; it also takes no effect when memory_optimize_level is set to 'O1'.
- OFF: Disable the memory offload function.
- **ascend_config** (dict) - Set parameters specific to the Ascend hardware platform. Not set by default. Currently it is only supported on the Ascend910B hardware platform and takes no effect on other platforms.

  - **precision_mode** (str): Mixed precision mode setting. On the Ascend910B hardware platform, the default value for training networks is must_keep_origin_dtype and the default value for inference networks is force_fp16. The value range is as follows:

    - force_fp16: When the operator supports both float16 and float32, select float16 directly.
    - allow_fp32_to_fp16: When the operator does not support the float32 data type, reduce the precision to float16 directly.
    - allow_mix_precision: Automatic mixed precision. For operators of the whole network, the precision of some operators is automatically reduced to float16 or bfloat16 according to the built-in optimization strategy.
    - must_keep_origin_dtype: Keep the precision of the original graph.
    - force_fp32: When the operator supports both float16 and float32, select float32 directly.
    - force_lowerprecision: When the operator supports both float32 and float16 or bfloat16, select float16 or bfloat16 directly.
    - allow_fp32_to_bf16: When the operator does not support the float32 data type, reduce the precision to bfloat16 directly.
    - allow_fp32_to_lowprecision: When the operator does not support the float32 data type, reduce the precision to float16 or bfloat16 directly.
    - allow_mix_precision_fp16: Automatic mixed precision. For operators of the whole network, the precision of some operators is automatically reduced to float16 according to the built-in optimization strategy.
    - allow_mix_precision_bf16: Automatic mixed precision. For operators of the whole network, the precision of some operators is automatically reduced to bfloat16 according to the built-in optimization strategy.

  - **jit_compile** (bool): Whether to select online compilation. Default: True. When set to False, pre-compiled operator binary files in the system are preferred, which improves compilation performance.

Raises:
- **ValueError** - If the input key is not an attribute in the context.
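For reference, a minimal usage sketch of the new option (it mirrors the doctest added later in this PR and assumes an Ascend environment):

    import mindspore as ms

    ms.set_context(device_target="Ascend")
    # precision_mode and jit_compile are passed through ascend_config;
    # setting them as top-level context keys raises ValueError (see below).
    ms.set_context(ascend_config={"precision_mode": "must_keep_origin_dtype",
                                  "jit_compile": False})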
@@ -20,6 +20,7 @@
#include <algorithm>
#include <utility>
#include <map>
#include <set>
#include "include/transform/graph_ir/types.h"
#include "include/transform/graph_ir/utils.h"
#include "include/common/utils/utils.h"
@@ -50,6 +51,7 @@ constexpr auto kMindsporeDumpConfig = "MINDSPORE_DUMP_CONFIG";
constexpr auto kOpDebugConfigFile = "ge_op_debug_config.ini";
constexpr char kGeDumpMode[3][7] = {"all", "input", "output"};
const std::set<std::string> kIgnoreGEShapeOps = {kSoftMarginLossOpName};
const std::set<std::string> kAscend910BVersions = {"Ascend910B1", "Ascend910B2", "Ascend910B3", "Ascend910B4"};

std::string GetGraphName(const FuncGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
@@ -577,6 +579,32 @@ void UseOpDebugConfig(std::map<std::string, std::string> *ge_options) {
    MS_LOG(INFO) << "Use MS_COMPILER_OP_DEBUG_CONFIG:" << ge_op_debug_config;
  }
}

void GeDeviceContext::SetAscendConfig(const std::shared_ptr<MsContext> &ms_context_ptr,
                                      std::map<std::string, std::string> *ge_options) {
  MS_EXCEPTION_IF_NULL(ms_context_ptr);
  MS_EXCEPTION_IF_NULL(ge_options);
  if (ms_context_ptr->get_param<std::string>(MS_CTX_PRECISION_MODE) != "") {
    (*ge_options)["ge.exec.precision_mode"] = ms_context_ptr->get_param<std::string>(MS_CTX_PRECISION_MODE);
    MS_LOG(INFO) << "Set precision_mode " << ms_context_ptr->get_param<std::string>(MS_CTX_PRECISION_MODE) << ".";
  } else if (IsGeTrain()) {
    auto soc_version = device::ascend::GetSocVersion();
    if (kAscend910BVersions.count(soc_version) != 0) {
      (*ge_options)["ge.exec.precision_mode"] = "must_keep_origin_dtype";
      MS_LOG(INFO) << "Set precision_mode must_keep_origin_dtype. soc_version is " << soc_version;
    } else {
      (*ge_options)["ge.exec.precision_mode"] = "allow_fp32_to_fp16";
      MS_LOG(INFO) << "Set precision_mode allow_fp32_to_fp16. soc_version is " << soc_version;
    }
  } else {
    (*ge_options)["ge.exec.precision_mode"] = "force_fp16";
  }

  if (ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_JIT_COMPILE)) {
    (*ge_options)["ge.jit_compile"] = "1";
  } else {
    (*ge_options)["ge.jit_compile"] = "0";
  }
}

void GeDeviceContext::GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr,
                                   std::map<std::string, std::string> *ge_options) {
@@ -674,11 +702,7 @@ void GeDeviceContext::GetGeOptions(const std::shared_ptr<MsContext> &ms_context_
     MS_LOG(WARNING) << "Set proto lib path failed!";
   }

-  if (training) {
-    (*ge_options)["ge.exec.precision_mode"] = "allow_mix_precision";
-  } else {
-    (*ge_options)["ge.exec.precision_mode"] = "force_fp16";
-  }
+  SetAscendConfig(ms_context_ptr, ge_options);

   (*ge_options)["ge.enableSmallChannel"] = "1";
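For readers skimming the C++ above, a hedged Python paraphrase of the precision_mode selection that SetAscendConfig (added above) now performs; default_precision_mode is a hypothetical name for illustration, not a MindSpore API:

    ASCEND910B_VERSIONS = {"Ascend910B1", "Ascend910B2", "Ascend910B3", "Ascend910B4"}

    def default_precision_mode(user_mode, is_train, soc_version):
        # An explicit ascend_config={'precision_mode': ...} always wins.
        if user_mode:
            return user_mode
        if is_train:
            if soc_version in ASCEND910B_VERSIONS:
                return "must_keep_origin_dtype"  # new Ascend910B training default
            return "allow_fp32_to_fp16"          # training default on other SoCs
        return "force_fp16"                      # inference default

Note that this call replaces the old unconditional training default of allow_mix_precision, removed from GetGeOptions just above.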
@@ -96,6 +96,8 @@ class GeDeviceContext : public DeviceInterface<GeGraphExecutor, GeDeviceResManag
  bool FinalizeGe(const std::shared_ptr<MsContext> &inst_context);
  void GetGeOptions(const std::shared_ptr<MsContext> &inst_context, std::map<std::string, std::string> *ge_options);
  void SetHcclOptions(const std::shared_ptr<MsContext> &inst_context, std::map<std::string, std::string> *ge_options);
  void SetAscendConfig(const std::shared_ptr<MsContext> &ms_context_ptr,
                       std::map<std::string, std::string> *ge_options);
  void SetDisableReuseMemoryFlag(std::map<std::string, std::string> *ge_options) const;

  std::unique_ptr<AscendDeprecatedInterface> deprecated_interface_;
@@ -93,6 +93,8 @@ void RegMsContext(const py::module *m) {
    .value("profiling_options", MsCtxParam::MS_CTX_PROFILING_OPTIONS)
    .value("save_dump_path", MsCtxParam::MS_CTX_SAVE_DUMP_PATH)
    .value("deterministic", MsCtxParam::MS_CTX_DETERMINISTIC)
    .value("precision_mode", MsCtxParam::MS_CTX_PRECISION_MODE)
    .value("jit_compile", MsCtxParam::MS_CTX_ENABLE_JIT_COMPILE)
    .value("save_graphs_path", MsCtxParam::MS_CTX_SAVE_GRAPHS_PATH)
    .value("compile_cache_path", MsCtxParam::MS_CTX_COMPILE_CACHE_PATH)
    .value("variable_memory_max_size", MsCtxParam::MS_CTX_VARIABLE_MEMORY_MAX_SIZE)
@@ -64,6 +64,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
  set_param<bool>(MS_CTX_ENABLE_DUMP, false);
  set_param<std::string>(MS_CTX_SAVE_DUMP_PATH, ".");
  set_param<std::string>(MS_CTX_DETERMINISTIC, "OFF");
  set_param<std::string>(MS_CTX_PRECISION_MODE, "");  // empty string means "not configured"; SetAscendConfig then picks a default
  set_param<std::string>(MS_CTX_ENV_CONFIG_PATH, "");
  set_param<std::string>(MS_CTX_TUNE_MODE, "NO_TUNE");
  set_param<std::string>(MS_CTX_GRAPH_KERNEL_FLAGS, "");
@@ -117,6 +118,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
  set_param<bool>(MS_CTX_ENABLE_RECOVERY, false);
  set_param<bool>(MS_CTX_ENABLE_GE_HETEROGENOUS, false);
  set_param<bool>(MS_CTX_DISABLE_FORMAT_TRANSFORM, false);
  set_param<bool>(MS_CTX_ENABLE_JIT_COMPILE, true);
  set_param<int>(MS_CTX_MEMORY_OPTIMIZE_LEVEL, kOptimizeO0);
  set_param<uint32_t>(MS_CTX_OP_TIMEOUT, kOpTimeout);
@@ -96,6 +96,7 @@ enum MsCtxParam : unsigned {
  MS_CTX_ENABLE_RECOVERY,
  MS_CTX_ENABLE_GE_HETEROGENOUS,
  MS_CTX_DISABLE_FORMAT_TRANSFORM,
  MS_CTX_ENABLE_JIT_COMPILE,
  MS_CTX_TYPE_BOOL_END,

  // parameter of type int
@@ -139,6 +140,7 @@ enum MsCtxParam : unsigned {
  MS_CTX_GRAPH_KERNEL_FLAGS,
  MS_CTX_INFER_PRECISION_MODE,  // GPU inference precision mode configured by Serving or Unify API.
  MS_CTX_DETERMINISTIC,
  MS_CTX_PRECISION_MODE,
  MS_CTX_TYPE_STRING_END,

  // parameter numbers of each type
@@ -244,6 +244,35 @@ class _Context:
                             f"{deterministic_options}, but got {deterministic}.")
        self.set_param(ms_ctx_param.deterministic, deterministic)

    def set_ascend_config(self, ascend_config):
        """
        Set the Ascend hardware-specific configuration.

        Args:
            ascend_config (dict): Supported keys are 'precision_mode' and 'jit_compile'.

                - precision_mode (str): "force_fp16", "allow_fp32_to_fp16", "allow_mix_precision",
                  "must_keep_origin_dtype", "force_fp32", "force_lowerprecision", "allow_fp32_to_bf16",
                  "allow_fp32_to_lowprecision", "allow_mix_precision_fp16" and "allow_mix_precision_bf16".
                - jit_compile (bool): True or False.
        """
        ascend_cfgs = {'precision_mode': ["force_fp16", "allow_fp32_to_fp16", "allow_mix_precision",
                                          "must_keep_origin_dtype", "force_fp32", "force_lowerprecision",
                                          "allow_fp32_to_bf16", "allow_fp32_to_lowprecision",
                                          "allow_mix_precision_fp16", "allow_mix_precision_bf16"],
                       'jit_compile': [True, False]}
        for ascend_key in ascend_config:
            if ascend_key not in ascend_cfgs:
                raise ValueError(f"For 'context.set_context', the key of argument 'ascend_config' must be one of "
                                 f"{ascend_cfgs}, but got {ascend_key}.")
            supported_modes = ascend_cfgs.get(ascend_key)
            if ascend_config[ascend_key] not in supported_modes:
                raise ValueError(f"For 'ascend_config', the value of argument {ascend_key} must be one of "
                                 f"{supported_modes}, but got {ascend_config[ascend_key]}.")
            if ascend_key == 'precision_mode':
                self.set_param(ms_ctx_param.precision_mode, ascend_config[ascend_key])
            if ascend_key == 'jit_compile':
                self.set_param(ms_ctx_param.jit_compile, ascend_config[ascend_key])

    def set_backend_policy(self, policy):
        success = self._context_handle.set_backend_policy(policy)
        if not success:
@@ -421,7 +450,8 @@ class _Context:
             'memory_optimize_level': set_memory_optimize_level,
             'op_timeout': set_op_timeout,
             'memory_offload': set_memory_offload,
-            'deterministic': set_deterministic
+            'deterministic': set_deterministic,
+            'ascend_config': set_ascend_config
         }

     @property
@@ -706,7 +736,8 @@ def _check_target_specific_cfgs(device, arg_key):
         'auto_tune_mode': ['Ascend'],
         'max_device_memory': ['Ascend', 'GPU'],
         'mempool_block_size': ['GPU', 'Ascend'],
-        'disable_format_transform': ['GPU']
+        'disable_format_transform': ['GPU'],
+        'ascend_config': ['Ascend']
     }
     # configs not in map device_cfgs are supposed to be suitable for all devices
     if arg_key not in device_cfgs:
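The diff truncates the body of _check_target_specific_cfgs, so as a hedged sketch (the exact return and warning behavior is an assumption), the gate works roughly like this:

    def _check_target_specific_cfgs_sketch(device, arg_key):
        # Keys absent from device_cfgs are treated as valid on every device.
        device_cfgs = {'ascend_config': ['Ascend'], 'disable_format_transform': ['GPU']}
        if arg_key not in device_cfgs:
            return True
        return device in device_cfgs[arg_key]

With the new entry, ascend_config is screened out unless device_target is Ascend.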
@@ -728,7 +759,7 @@ def _check_target_specific_cfgs(device, arg_key):
                  max_device_memory=str, print_file_path=str, max_call_depth=int, env_config_path=str,
                  graph_kernel_flags=str, save_compile_cache=bool, runtime_num_threads=int, load_compile_cache=bool,
                  grad_for_scalar=bool, pynative_synchronize=bool, mempool_block_size=str, disable_format_transform=bool,
-                 op_timeout=int, deterministic=str)
+                 op_timeout=int, deterministic=str, ascend_config=dict)
 def set_context(**kwargs):
     """
     Set context for running environment.
@@ -809,6 +840,8 @@ def set_context(**kwargs):
    |                         | memory_optimize_level        | CPU/GPU/Ascend             |
    |                         +------------------------------+----------------------------+
    |                         | memory_offload               | GPU/Ascend                 |
    |                         +------------------------------+----------------------------+
    |                         | ascend_config                | Ascend                     |
    +-------------------------+------------------------------+----------------------------+

    Args:
@@ -982,7 +1015,33 @@ def set_context(**kwargs):
            when the environment variable "GRAPH_OP_RUN=1" is not set; this parameter does not take effect when
            memory_optimize_level is set to 'O1'.
          - OFF: Turn off the memory offload function.
        ascend_config (dict): Set the parameters specific to the Ascend hardware platform. It is not set by default.
            Currently, only 'precision_mode' and 'jit_compile' are supported on the Ascend910B hardware platform.

            - precision_mode (str): Mixed precision mode setting. On the Ascend910B hardware platform, the default
              value for training networks is must_keep_origin_dtype, and the default value for inference networks
              is force_fp16. The value range is as follows:

              - force_fp16: When the operator supports both float16 and float32, select float16 directly.
              - allow_fp32_to_fp16: When the operator does not support the float32 data type, reduce the
                precision to float16 directly.
              - allow_mix_precision: Automatic mixed precision. For operators of the whole network, the precision
                of some operators is automatically reduced to float16 or bfloat16 according to the built-in
                optimization strategy.
              - must_keep_origin_dtype: Keep the precision of the original graph.
              - force_fp32: When the operator supports both float16 and float32, select float32 directly.
              - force_lowerprecision: When the operator supports both float32 and float16 or bfloat16, select
                float16 or bfloat16 directly.
              - allow_fp32_to_bf16: When the operator does not support the float32 data type, reduce the
                precision to bfloat16 directly.
              - allow_fp32_to_lowprecision: When the operator does not support the float32 data type, reduce the
                precision to float16 or bfloat16 directly.
              - allow_mix_precision_fp16: Automatic mixed precision. For operators of the whole network, the
                precision of some operators is automatically reduced to float16 according to the built-in
                optimization strategy.
              - allow_mix_precision_bf16: Automatic mixed precision. For operators of the whole network, the
                precision of some operators is automatically reduced to bfloat16 according to the built-in
                optimization strategy.

            - jit_compile (bool): Whether to select online compilation. Default: True. When set to False,
              pre-compiled operator binary files in the system are preferred, which improves compilation
              performance.

    Raises:
        ValueError: If input key is not an attribute in context.
@@ -1014,6 +1073,7 @@ def set_context(**kwargs):
        >>> ms.set_context(memory_optimize_level='O0')
        >>> ms.set_context(memory_offload='ON')
        >>> ms.set_context(deterministic='ON')
        >>> ms.set_context(ascend_config={"precision_mode": "force_fp16", "jit_compile": True})
    """
    ctx = _context()
    # set device target first
@@ -1029,6 +1089,8 @@ def set_context(**kwargs):
            logger.warning(f"For 'context.set_context', '{key}' parameter is deprecated. "
                           "For details, please see the interface parameter API comments")
            continue
        if key in ('precision_mode', 'jit_compile'):
            raise ValueError(f"Please set '{key}' through parameter ascend_config")
        if key == 'save_graphs':
            if value is True:
                value = 2
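The effect of this new guard, paraphrasing the unit test added below (assumes device_target="Ascend"):

    import mindspore as ms

    ms.set_context(device_target="Ascend")
    try:
        ms.set_context(precision_mode="force_fp16")  # top-level key is now rejected
    except ValueError:
        # the supported path after this PR
        ms.set_context(ascend_config={"precision_mode": "force_fp16"})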
@@ -99,6 +99,31 @@ def test_max_device_memory_size():
    context.set_context(max_device_memory="3.5G")
    context.set_context.__wrapped__(max_device_memory="3GB")


def test_ascend_config():
    """
    Feature: test_ascend_config function
    Description: Test case for simplest ascend_config
    Expectation: The results are as expected
    """
    context.set_context(device_target="Ascend")
    with pytest.raises(ValueError):
        context.set_context(precision_mode="force_fp16")
    with pytest.raises(ValueError):
        context.set_context(jit_compile=True)
    with pytest.raises(ValueError):
        context.set_context(ascend_config={"precision_mode": "xxx"})
    with pytest.raises(ValueError):
        context.set_context(ascend_config={"xxxx": 1})
    with pytest.raises(ValueError):
        context.set_context(ascend_config={"jit_compile": "xxx"})
    with pytest.raises(ValueError):
        context.set_context(ascend_config={"jit_compile": 2})
    with pytest.raises(ValueError):
        context.set_context(ascend_config={"precision_mode": 2})
    context.set_context.__wrapped__(ascend_config={"precision_mode": "force_fp16", "jit_compile": True})


def test_print_file_path():
    """test_print_file_path"""
    with pytest.raises(IOError):