!15182 [GraphKernel] Add switch for parallel fusion

From: @tronzhang
Reviewed-by: @gaoxiong1,@dylangeng
Signed-off-by: @dylangeng
mindspore-ci-bot 2021-04-15 09:48:56 +08:00 committed by Gitee
commit 93d905333a
4 changed files with 15 additions and 3 deletions

@@ -21,6 +21,7 @@
#include "ir/func_graph.h"
#include "utils/ms_context.h"
#include "utils/context/graph_kernel_flags.h"
#include "backend/optimizer/graph_kernel/add_atomic_clean.h"
#include "backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.h"
#include "backend/optimizer/graph_kernel/arithmetic_simplify.h"
@@ -138,7 +139,7 @@ PassManagerPtr GraphKernelOptimizer::HighLevelOpt2() {
PassManagerPtr GraphKernelOptimizer::Combine() {
auto pm = std::make_shared<PassManager>("graphkernel_stage6_combine");
// Enable parallel fusion
- if (is_gpu) {
+ if (is_gpu && context::GraphKernelFlags::GetInstance().enable_parallel_fusion) {
// Do parallel fusion for gpu device
pm->AddPass(std::make_shared<ParallelOpFusion>(kGPUDevice, ParallelConfig(7)));
}
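
With this change, the ParallelOpFusion pass is no longer added unconditionally on GPU; it only runs when the new enable_parallel_fusion flag is set. A minimal sketch of flipping the switch from the Python side, using the context.set_context / graph_kernel_flags interface that appears later in this commit (the mode and device settings are the usual prerequisites, not part of this diff):

    from mindspore import context

    # Graph kernel runs in graph mode on GPU; the flag has no effect otherwise.
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    # Opt in to parallel fusion explicitly via the new boolean flag.
    context.set_context(enable_graph_kernel=True,
                        graph_kernel_flags="--enable_parallel_fusion=true")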

@@ -157,14 +157,17 @@ void GraphKernelFlags::Refresh() {
void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_map) {
FlagRegister reg(flag_map);
// Boolean flags
reg.AddFlag("dump_as_text", &dump_as_text);
reg.AddFlag("enable_stitch_fusion", &enable_stitch_fusion);
reg.AddFlag("enable_parallel_fusion", &enable_parallel_fusion);
// Integer flags
reg.AddFlag("opt_level", &opt_level);
reg.AddFlag("auto_tune", &auto_tune);
reg.AddFlag("cluster_limit", &cluster_limit);
// String list flags
reg.AddFlag("enable_expand_ops", &enable_expand_ops);
reg.AddFlag("enable_expand_ops_only", &enable_expand_ops_only);
reg.AddFlag("disable_expand_ops", &disable_expand_ops);
@@ -177,8 +180,10 @@ void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_ma
std::string GraphKernelFlags::DumpAllFlags() const {
nlohmann::json json;
json["dump_as_text"] = dump_as_text;
json["enable_stitch_fusion"] = enable_stitch_fusion;
json["enable_parallel_fusion"] = enable_parallel_fusion;
json["opt_level"] = opt_level;
json["auto_tune"] = auto_tune;

@@ -59,6 +59,11 @@ class GraphKernelFlags {
*/
bool enable_stitch_fusion{false};
+ /**
+ * Enable parallel fusion in graph kernel fusion strategy.
+ */
+ bool enable_parallel_fusion{false};
/**
* Optimization level, value from 0 to 3.
* 0: GraphKernel disabled

@@ -135,7 +135,8 @@ def _auto_enable_graph_kernel(device_target, graph_kernel_mode):
def _set_graph_kernel_context(device_target, enable_graph_kernel, is_auto_enable_graph_kernel):
if enable_graph_kernel == "true" or is_auto_enable_graph_kernel:
if device_target == 'GPU':
- context.set_context(enable_graph_kernel=True, graph_kernel_flags="--enable_stitch_fusion=true")
+ context.set_context(enable_graph_kernel=True,
+                     graph_kernel_flags="--enable_stitch_fusion=true --enable_parallel_fusion=true")
else:
logger.warning('Graph kernel only supports GPU back-end now, run with graph kernel off.')
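
For completeness, the automatic GPU default above now turns on both stitch fusion and parallel fusion whenever graph kernel is auto-enabled. A hedged sketch of keeping graph kernel on while leaving the new parallel fusion switch off, by setting the flags explicitly in user code (whether this overrides the automatic default depends on when _set_graph_kernel_context runs relative to the user call, so treat it as illustrative):

    from mindspore import context

    # Keep stitch fusion, but leave parallel fusion disabled.
    context.set_context(enable_graph_kernel=True,
                        graph_kernel_flags="--enable_stitch_fusion=true --enable_parallel_fusion=false")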