forked from mindspore-Ecosystem/mindspore
!15182 [GraphKernel] Add switch for parallel fusion
From: @tronzhang Reviewed-by: @gaoxiong1, @dylangeng Signed-off-by: @dylangeng
Commit: 93d905333a
@@ -21,6 +21,7 @@
 #include "ir/func_graph.h"
 #include "utils/ms_context.h"
+#include "utils/context/graph_kernel_flags.h"
 #include "backend/optimizer/graph_kernel/add_atomic_clean.h"
 #include "backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.h"
 #include "backend/optimizer/graph_kernel/arithmetic_simplify.h"

@@ -138,7 +139,7 @@ PassManagerPtr GraphKernelOptimizer::HighLevelOpt2() {
 PassManagerPtr GraphKernelOptimizer::Combine() {
   auto pm = std::make_shared<PassManager>("graphkernel_stage6_combine");
   // Enable parallel fusion
-  if (is_gpu) {
+  if (is_gpu && context::GraphKernelFlags::GetInstance().enable_parallel_fusion) {
     // Do parallel fusion for gpu device
     pm->AddPass(std::make_shared<ParallelOpFusion>(kGPUDevice, ParallelConfig(7)));
   }

@@ -157,14 +157,17 @@ void GraphKernelFlags::Refresh() {
 void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_map) {
   FlagRegister reg(flag_map);

   // Boolean flags
   reg.AddFlag("dump_as_text", &dump_as_text);

   reg.AddFlag("enable_stitch_fusion", &enable_stitch_fusion);
+  reg.AddFlag("enable_parallel_fusion", &enable_parallel_fusion);

   // Integer flags
   reg.AddFlag("opt_level", &opt_level);
   reg.AddFlag("auto_tune", &auto_tune);
   reg.AddFlag("cluster_limit", &cluster_limit);

   // String list flags
   reg.AddFlag("enable_expand_ops", &enable_expand_ops);
   reg.AddFlag("enable_expand_ops_only", &enable_expand_ops_only);
   reg.AddFlag("disable_expand_ops", &disable_expand_ops);

@@ -177,8 +180,10 @@ void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_map) {

 std::string GraphKernelFlags::DumpAllFlags() const {
   nlohmann::json json;

   json["dump_as_text"] = dump_as_text;
   json["enable_stitch_fusion"] = enable_stitch_fusion;
+  json["enable_parallel_fusion"] = enable_parallel_fusion;

   json["opt_level"] = opt_level;
   json["auto_tune"] = auto_tune;

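The two hunks above register the new boolean flag and include it in the flag dump. The value itself comes from the graph_kernel_flags string that users pass to context.set_context (see the Python hunk below), which uses a "--name=value" format. The snippet below is only a rough Python sketch of how such a string maps onto boolean flags like enable_parallel_fusion; it is not MindSpore's FlagRegister, and the helper names (parse_flag_string, parse_bool_flag) are made up for illustration.

# Illustrative sketch only: a Python stand-in for the "--name=value" flag format.
# MindSpore's real parsing is done in C++ by FlagRegister; these helpers are hypothetical.

def parse_flag_string(flag_string):
    """Split a string such as '--enable_parallel_fusion=true' into a name->value dict."""
    flags = {}
    for token in flag_string.split():
        if not token.startswith("--"):
            continue
        name, _, value = token[2:].partition("=")
        flags[name] = value if value else "true"
    return flags

def parse_bool_flag(flags, name, default=False):
    """Read a boolean flag, falling back to its default when it is absent."""
    value = flags.get(name)
    return default if value is None else value.lower() == "true"

flags = parse_flag_string("--enable_stitch_fusion=true --enable_parallel_fusion=true")
assert parse_bool_flag(flags, "enable_parallel_fusion") is True
assert parse_bool_flag(flags, "dump_as_text") is False  # unset flags keep their defaults
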
@@ -59,6 +59,11 @@ class GraphKernelFlags {
    */
   bool enable_stitch_fusion{false};

+  /**
+   * Enable parallel fusion in graph kernel fusion strategy.
+   */
+  bool enable_parallel_fusion{false};
+
   /**
    * Optimization level, value from 0 to 3.
    * 0: GraphKernel disabled

@@ -135,7 +135,8 @@ def _auto_enable_graph_kernel(device_target, graph_kernel_mode):
 def _set_graph_kernel_context(device_target, enable_graph_kernel, is_auto_enable_graph_kernel):
     if enable_graph_kernel == "true" or is_auto_enable_graph_kernel:
         if device_target == 'GPU':
-            context.set_context(enable_graph_kernel=True, graph_kernel_flags="--enable_stitch_fusion=true")
+            context.set_context(enable_graph_kernel=True,
+                                graph_kernel_flags="--enable_stitch_fusion=true --enable_parallel_fusion=true")
         else:
             logger.warning('Graph kernel only supports GPU back-end now, run with graph kernel off.')
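With this change, the auto-enable path on GPU turns on both stitch fusion and the new parallel fusion switch. A user script can also set the flag explicitly through the same context.set_context call; the following is a minimal sketch assuming a GPU build of MindSpore that includes this commit, with the model and training code omitted.

from mindspore import context

# Run in graph mode on GPU and enable graph kernel fusion.
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
# Turn on the new parallel fusion switch via the flag string registered above.
context.set_context(enable_graph_kernel=True,
                    graph_kernel_flags="--enable_parallel_fusion=true")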