From 6fdd52080dd1f416da9b881f15e9cbd1829e8190 Mon Sep 17 00:00:00 2001 From: lizhenyu Date: Sat, 29 Aug 2020 13:11:36 +0800 Subject: [PATCH] add mode black list checker --- .../ccsrc/backend/session/gpu_session.cc | 27 +++++++++++++++++-- mindspore/ccsrc/backend/session/gpu_session.h | 5 ++++ .../runtime/device/gpu/kernel_info_setter.cc | 4 +-- .../runtime/device/gpu/kernel_info_setter.h | 2 +- 4 files changed, 33 insertions(+), 5 deletions(-) diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc index 23c29019e5d..b5d466740a6 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.cc +++ b/mindspore/ccsrc/backend/session/gpu_session.cc @@ -49,9 +49,10 @@ using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; void GPUSession::SelectKernel(const std::shared_ptr &kernel_graph) const { MS_EXCEPTION_IF_NULL(kernel_graph); + bool in_black_list = CheckInModeBlackList(kernel_graph); for (const auto &kernel_node : kernel_graph->execution_order()) { MS_EXCEPTION_IF_NULL(kernel_node); - device::gpu::SetKernelInfo(kernel_node); + device::gpu::SetKernelInfo(kernel_node, in_black_list); } } @@ -75,7 +76,7 @@ void GPUSession::Optimize(const std::shared_ptr &kernel_graph) { pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); - if (context_ptr->execution_mode() != kPynativeMode) { + if (!CheckInModeBlackList(kernel_graph) && context_ptr->execution_mode() != kPynativeMode) { pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); @@ -192,6 +193,28 @@ void GPUSession::Execute(const std::shared_ptr &kernel_graph) const } } +bool GPUSession::CheckInModeBlackList(const std::shared_ptr &kernel_graph) const { + auto kernels = kernel_graph->execution_order(); + size_t conv_cnt = 0; + size_t bn_cnt = 0; + for (const auto &kernel : kernels) { + auto kernel_name = AnfAlgo::GetCNodeName(kernel); + if (kernel_name == prim::kPrimLayerNorm->name()) { + return true; + } + if (kernel_name == prim::kPrimConv2D->name()) { + conv_cnt++; + } + if (kernel_name == prim::kPrimFusedBatchNormEx->name()) { + bn_cnt++; + } + } + if (conv_cnt == kConv2dCount && bn_cnt == kFusedBatchNormCount) { + return true; + } + return false; +} + GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { // Construct graph, if successfully, graph_sum_ + 1 auto graph_id = graph_sum_; diff --git a/mindspore/ccsrc/backend/session/gpu_session.h b/mindspore/ccsrc/backend/session/gpu_session.h index f79ae4e8d56..bd28e97b087 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.h +++ b/mindspore/ccsrc/backend/session/gpu_session.h @@ -67,6 +67,8 @@ class GPUSession : public SessionBasic { void Execute(const std::shared_ptr &kernel_graph) const; + bool CheckInModeBlackList(const std::shared_ptr &kernel_graph) const; + #ifdef ENABLE_DEBUGGER void Dump(const std::shared_ptr &kernel_graph) const; @@ -80,6 +82,9 @@ class GPUSession : public SessionBasic { void PostLoadTensor(const std::shared_ptr &kernel_graph) const; #endif + + static constexpr size_t kConv2dCount = 96; + static constexpr size_t kFusedBatchNormCount = 94; }; using GPUSessionPtr = std::shared_ptr; MS_REG_SESSION(kGPUDevice, GPUSession); diff --git a/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc index 2649a43df22..b2801974c85 100644 --- a/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc +++ b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc @@ -223,7 +223,7 @@ void UpdateKernelFormatInfo(const CNodePtr &kernel_node, const std::vector inputs_format; std::vector inputs_type; for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { @@ -237,7 +237,7 @@ void SetKernelInfo(const CNodePtr &kernel_node) { outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index)); } std::string origin_data_format = kOpFormat_DEFAULT; - if (IsNeedProcessFormatInfo(kernel_node, inputs_type)) { + if (!in_black_list && IsNeedProcessFormatInfo(kernel_node, inputs_type)) { UpdateKernelFormatInfo(kernel_node, inputs_type, &inputs_format, &outputs_format, &origin_data_format); } std::shared_ptr builder = diff --git a/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.h b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.h index 4cd03e3e306..0f64527b4de 100644 --- a/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.h +++ b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.h @@ -53,7 +53,7 @@ static std::map, std::vector> {prim::kPrimAddN->name(), {{}, {0}}}, }; -void SetKernelInfo(const CNodePtr &apply_kernel_ptr); +void SetKernelInfo(const CNodePtr &kernel_node, bool in_black_list = false); class KernelAttr { public: