From 8ff3c1677845298486e0e82a006d04eb54d2fbda Mon Sep 17 00:00:00 2001
From: tronzhang
Date: Wed, 14 Apr 2021 17:36:13 +0800
Subject: [PATCH] add switch for parallel fusion and default is off

---
 .../optimizer/graph_kernel/graph_kernel_optimization.cc | 3 ++-
 mindspore/ccsrc/utils/context/graph_kernel_flags.cc     | 7 ++++++-
 mindspore/ccsrc/utils/context/graph_kernel_flags.h      | 5 +++++
 model_zoo/official/nlp/bert/run_pretrain.py             | 3 ++-
 4 files changed, 15 insertions(+), 3 deletions(-)

diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_optimization.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_optimization.cc
index 77140e02615..31f7d8ec1cd 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_optimization.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_optimization.cc
@@ -21,6 +21,7 @@
 
 #include "ir/func_graph.h"
 #include "utils/ms_context.h"
+#include "utils/context/graph_kernel_flags.h"
 #include "backend/optimizer/graph_kernel/add_atomic_clean.h"
 #include "backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.h"
 #include "backend/optimizer/graph_kernel/arithmetic_simplify.h"
@@ -138,7 +139,7 @@ PassManagerPtr GraphKernelOptimizer::HighLevelOpt2() {
 PassManagerPtr GraphKernelOptimizer::Combine() {
   auto pm = std::make_shared<PassManager>("graphkernel_stage6_combine");
   // Enable parallel fusion
-  if (is_gpu) {
+  if (is_gpu && context::GraphKernelFlags::GetInstance().enable_parallel_fusion) {
     // Do parallel fusion for gpu device
     pm->AddPass(std::make_shared<ParallelOpFusion>(kGPUDevice, ParallelConfig(7)));
   }
diff --git a/mindspore/ccsrc/utils/context/graph_kernel_flags.cc b/mindspore/ccsrc/utils/context/graph_kernel_flags.cc
index b8a7ae72d2c..7bc67993d65 100644
--- a/mindspore/ccsrc/utils/context/graph_kernel_flags.cc
+++ b/mindspore/ccsrc/utils/context/graph_kernel_flags.cc
@@ -157,14 +157,17 @@ void GraphKernelFlags::Refresh() {
 void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_map) {
   FlagRegister reg(flag_map);
 
+  // Boolean flags
   reg.AddFlag("dump_as_text", &dump_as_text);
-  reg.AddFlag("enable_stitch_fusion", &enable_stitch_fusion);
+  reg.AddFlag("enable_parallel_fusion", &enable_parallel_fusion);
 
+  // Integer flags
   reg.AddFlag("opt_level", &opt_level);
   reg.AddFlag("auto_tune", &auto_tune);
   reg.AddFlag("cluster_limit", &cluster_limit);
 
+  // String list flags
   reg.AddFlag("enable_expand_ops", &enable_expand_ops);
   reg.AddFlag("enable_expand_ops_only", &enable_expand_ops_only);
   reg.AddFlag("disable_expand_ops", &disable_expand_ops);
@@ -177,8 +180,10 @@ void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_ma
 
 std::string GraphKernelFlags::DumpAllFlags() const {
   nlohmann::json json;
+
   json["dump_as_text"] = dump_as_text;
   json["enable_stitch_fusion"] = enable_stitch_fusion;
+  json["enable_parallel_fusion"] = enable_parallel_fusion;
 
   json["opt_level"] = opt_level;
   json["auto_tune"] = auto_tune;
diff --git a/mindspore/ccsrc/utils/context/graph_kernel_flags.h b/mindspore/ccsrc/utils/context/graph_kernel_flags.h
index 4b1c037a51e..7d247cea4fe 100644
--- a/mindspore/ccsrc/utils/context/graph_kernel_flags.h
+++ b/mindspore/ccsrc/utils/context/graph_kernel_flags.h
@@ -59,6 +59,11 @@ class GraphKernelFlags {
    */
   bool enable_stitch_fusion{false};
 
+  /**
+   * Enable parallel fusion in graph kernel fusion strategy.
+   */
+  bool enable_parallel_fusion{false};
+
   /**
    * Optimization level, value from 0 to 3.
    * 0: GraphKernel disabled
diff --git a/model_zoo/official/nlp/bert/run_pretrain.py b/model_zoo/official/nlp/bert/run_pretrain.py
index bb5264d127f..1d762daf138 100644
--- a/model_zoo/official/nlp/bert/run_pretrain.py
+++ b/model_zoo/official/nlp/bert/run_pretrain.py
@@ -135,7 +135,8 @@ def _auto_enable_graph_kernel(device_target, graph_kernel_mode):
 def _set_graph_kernel_context(device_target, enable_graph_kernel, is_auto_enable_graph_kernel):
     if enable_graph_kernel == "true" or is_auto_enable_graph_kernel:
         if device_target == 'GPU':
-            context.set_context(enable_graph_kernel=True, graph_kernel_flags="--enable_stitch_fusion=true")
+            context.set_context(enable_graph_kernel=True,
+                                graph_kernel_flags="--enable_stitch_fusion=true --enable_parallel_fusion=true")
         else:
             logger.warning('Graph kernel only supports GPU back-end now, run with graph kernel off.')
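
For reference, the new switch is consumed through the existing graph_kernel_flags context string, as the run_pretrain.py hunk above shows. Below is a minimal, hypothetical sketch of turning it on from a user script; the GRAPH_MODE/device_target setup is an assumption and not part of this patch, and the flag stays off unless it is passed explicitly.

    # Hypothetical usage sketch (not part of the patch): enable the new
    # parallel-fusion switch via graph_kernel_flags; its default is off.
    from mindspore import context

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    context.set_context(enable_graph_kernel=True,
                        graph_kernel_flags="--enable_parallel_fusion=true")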