!5667 add kernel select after optimize pass

Merge pull request !5667 from zyli2020/code_refactor
This commit is contained in:
mindspore-ci-bot 2020-09-04 14:24:44 +08:00 committed by Gitee
commit bc4c5afc1a
12 changed files with 103 additions and 58 deletions

View File

@ -96,9 +96,15 @@ class ActivationGradGpuKernel : public GpuKernel {
const int split_dim = 4;
if (input_shape.size() <= split_dim) {
ShapeNdTo4d(input_shape, &shape);
if (AnfAlgo::GetInputFormat(kernel_node, 0) == kOpFormat_NHWC) {
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NHWC, cudnn_data_type_,
shape[0], shape[3], shape[1], shape[2]),
"cudnnSetTensor4dDescriptor failed");
} else {
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
shape[0], shape[1], shape[2], shape[3]),
"SetTensor4dDescriptor failed");
"cudnnSetTensor4dDescriptor failed");
}
} else {
CudnnSetTensorNdDescriptor(input_shape, data_descriptor_, cudnn_data_type_);
}

View File

@ -2,7 +2,6 @@ file(GLOB_RECURSE _PREACTIVATE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"common/*.cc"
"mem_reuse/*.cc"
"pass/*.cc"
"gpu/*.cc"
)
if (ENABLE_D)
@ -10,5 +9,10 @@ if (ENABLE_D)
list(APPEND _PREACTIVATE_SRC_LIST ${_D_SRC_LIST})
endif ()
if (ENABLE_GPU)
file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu/*.cc")
list(APPEND _PREACTIVATE_SRC_LIST ${_GPU_SRC_LIST})
endif ()
set_property(SOURCE ${_PREACTIVATE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PRE_ACT)
add_library(_mindspore_backend_optimizer_obj OBJECT ${_PREACTIVATE_SRC_LIST})

View File

@ -23,6 +23,7 @@
#include "ir/primitive.h"
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/gpu/kernel_info_setter.h"
namespace mindspore {
namespace opt {
@ -46,7 +47,7 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons
auto batch_norm_ex = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0);
MS_EXCEPTION_IF_NULL(batch_norm_ex);
if (AnfAlgo::GetOutputInferDataType(batch_norm_ex, 0) != kNumberTypeFloat16) {
if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC) {
return nullptr;
}
@ -83,6 +84,7 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
manager->Replace(batch_norm_ex, fused_batch_norm_with_add_relu);
device::gpu::SetKernelInfo(fused_batch_norm_with_add_relu);
return tuple_get_item;
}
} // namespace opt

View File

@ -24,6 +24,7 @@
#include "ir/primitive.h"
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/gpu/kernel_info_setter.h"
namespace mindspore {
namespace opt {
@ -123,7 +124,8 @@ const AnfNodePtr BatchNormAddReluGradFusion::Process(const FuncGraphPtr &graph,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
if (AnfAlgo::GetOutputInferDataType(node, 0) != kNumberTypeFloat16) {
if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC) {
return nullptr;
}
@ -169,6 +171,7 @@ const AnfNodePtr BatchNormAddReluGradFusion::Process(const FuncGraphPtr &graph,
AnfAlgo::CopyNodeAttrs(node, fused_batch_norm_add_relu_grad);
SetShapeAndType(fused_batch_norm_add_relu_grad, node, relu_grad);
ReplaceOutput(graph, node, relu_grad, fused_batch_norm_add_relu_grad);
device::gpu::SetKernelInfo(fused_batch_norm_add_relu_grad);
return nullptr;
}
} // namespace opt

View File

@ -23,6 +23,7 @@
#include "ir/primitive.h"
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/gpu/kernel_info_setter.h"
namespace mindspore {
namespace opt {
@ -43,7 +44,7 @@ const AnfNodePtr BatchNormReluFusion::Process(const FuncGraphPtr &graph, const A
auto batch_norm_ex = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0);
MS_EXCEPTION_IF_NULL(batch_norm_ex);
if (AnfAlgo::GetOutputInferDataType(batch_norm_ex, 0) != kNumberTypeFloat16) {
if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC) {
return nullptr;
}
@ -78,6 +79,7 @@ const AnfNodePtr BatchNormReluFusion::Process(const FuncGraphPtr &graph, const A
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
manager->Replace(batch_norm_ex, fused_batch_norm_with_relu);
device::gpu::SetKernelInfo(fused_batch_norm_with_relu);
return tuple_get_item;
}
} // namespace opt

View File

@ -23,6 +23,7 @@
#include "ir/primitive.h"
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/gpu/kernel_info_setter.h"
namespace mindspore {
namespace opt {
@ -38,7 +39,7 @@ const AnfNodePtr BatchNormReluGradFusion::Process(const FuncGraphPtr &graph, con
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
if (AnfAlgo::GetOutputInferDataType(node, 0) != kNumberTypeFloat16) {
if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC) {
return nullptr;
}
@ -84,6 +85,7 @@ const AnfNodePtr BatchNormReluGradFusion::Process(const FuncGraphPtr &graph, con
}
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fused_batch_norm_grad_with_relu.get());
AnfAlgo::CopyNodeAttrs(node, fused_batch_norm_grad_with_relu);
device::gpu::SetKernelInfo(fused_batch_norm_grad_with_relu);
return fused_batch_norm_grad_with_relu;
}
} // namespace opt

View File

@ -53,10 +53,10 @@ using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm;
void GPUSession::SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
MS_EXCEPTION_IF_NULL(kernel_graph);
bool graph_format_transform = IsSupportFormatTransform(kernel_graph);
device::gpu::FormatTransformChecker::GetInstance().CheckSupportFormatTransform(kernel_graph);
for (const auto &kernel_node : kernel_graph->execution_order()) {
MS_EXCEPTION_IF_NULL(kernel_node);
device::gpu::SetKernelInfo(kernel_node, graph_format_transform);
device::gpu::SetKernelInfo(kernel_node);
}
}
@ -82,12 +82,6 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>());
pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
if (IsSupportFormatTransform(kernel_graph) && context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
pm->AddPass(std::make_shared<opt::BatchNormReluFusion>());
pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>());
pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>());
// pm->AddPass(std::make_shared<opt::BatchNormAddReluGradFusion>());
}
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();
@ -96,6 +90,10 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::BatchNormReluFusion>());
pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>());
pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>());
// pm->AddPass(std::make_shared<opt::BatchNormAddReluGradFusion>());
pm->AddPass(std::make_shared<opt::InsertFormatTransformOp>());
pm->AddPass(std::make_shared<opt::RemoveFormatTransformPair>());
pm->AddPass(std::make_shared<opt::RemoveRedundantFormatTransform>());
@ -201,28 +199,6 @@ void GPUSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const
}
}
bool GPUSession::IsSupportFormatTransform(const std::shared_ptr<KernelGraph> &kernel_graph) const {
auto kernels = kernel_graph->execution_order();
size_t conv_cnt = 0;
size_t bn_cnt = 0;
for (const auto &kernel : kernels) {
auto kernel_name = AnfAlgo::GetCNodeName(kernel);
if (kernel_name == prim::kPrimLayerNorm->name()) {
return false;
}
if (kernel_name == prim::kPrimConv2D->name()) {
conv_cnt++;
}
if (kernel_name == prim::kPrimFusedBatchNormEx->name()) {
bn_cnt++;
}
}
if (conv_cnt == kConv2dCount && bn_cnt == kFusedBatchNormCount) {
return false;
}
return true;
}
GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
// Construct graph, if successfully, graph_sum_ + 1
auto graph_id = graph_sum_;
@ -232,26 +208,27 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
// Optimize
// Dump .pb graph before graph optimization
if (save_graphs) {
DumpIRProto(graph, "before_opt_" + std::to_string(graph_id));
}
// Graph optimization irrelevant to device data format
Optimize(graph);
// Select kernel build info
SelectKernel(graph);
// Graph optimization relevant to device data format
HardwareOptimize(graph);
// Dump .pb graph after graph optimization
if (save_graphs) {
DumpIRProto(graph, "after_opt_" + std::to_string(graph_id));
}
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
// Assign parameter keys.
AssignParamKey(graph);
#endif
// Start gpu kernel runtime
StartKernelRT();
// Dump .pb graph before hardware optimization
if (save_graphs) {
DumpIRProto(graph, "before_hwopt_" + std::to_string(graph_id));
}
// HardwareOptimize
HardwareOptimize(graph);
// Dump .pb graph after hardware optimization
if (save_graphs) {
DumpIRProto(graph, "after_hwopt_" + std::to_string(graph_id));
}
// Assign CUDA streams
AssignStream(graph);
// Hide NopOp from execution graph

View File

@ -67,8 +67,6 @@ class GPUSession : public SessionBasic {
void Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const;
bool IsSupportFormatTransform(const std::shared_ptr<KernelGraph> &kernel_graph) const;
#ifdef ENABLE_DEBUGGER
void Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const;
@ -82,9 +80,6 @@ class GPUSession : public SessionBasic {
void PostLoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph) const;
#endif
static constexpr size_t kConv2dCount = 96;
static constexpr size_t kFusedBatchNormCount = 94;
};
using GPUSessionPtr = std::shared_ptr<GPUSession>;
MS_REG_SESSION(kGPUDevice, GPUSession);

View File

@ -47,7 +47,7 @@ bool CudaEnvChecker::CheckNvccInPath() {
};
auto cuda_paths = GetCudaRealPaths();
find_nvcc_ = any_of(cuda_paths.begin(), cuda_paths.end(), checker);
find_nvcc_ = std::any_of(cuda_paths.begin(), cuda_paths.end(), checker);
already_check_nvcc_ = true;
return find_nvcc_;
}

View File

@ -165,6 +165,9 @@ bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vector<Type
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
return false;
}
if (!FormatTransformChecker::GetInstance().format_transform()) {
return false;
}
if (!AnfAlgo::IsRealCNodeKernel(kernel_node)) {
return false;
}
@ -232,7 +235,31 @@ void UpdateKernelFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeI
}
} // namespace
void SetKernelInfo(const CNodePtr &kernel_node, bool graph_format_transform) {
void FormatTransformChecker::CheckSupportFormatTransform(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
auto kernels = kernel_graph->execution_order();
size_t conv_cnt = 0;
size_t bn_cnt = 0;
for (const auto &kernel : kernels) {
auto kernel_name = AnfAlgo::GetCNodeName(kernel);
if (kernel_name == prim::kPrimLayerNorm->name()) {
format_transform_ = false;
return;
}
if (kernel_name == prim::kPrimConv2D->name()) {
conv_cnt++;
}
if (kernel_name == prim::kPrimFusedBatchNormEx->name()) {
bn_cnt++;
}
}
if (conv_cnt == kConv2dCount && bn_cnt == kFusedBatchNormCount) {
format_transform_ = false;
return;
}
format_transform_ = true;
}
void SetKernelInfo(const CNodePtr &kernel_node) {
std::vector<std::string> inputs_format;
std::vector<TypeId> inputs_type;
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
@ -246,7 +273,7 @@ void SetKernelInfo(const CNodePtr &kernel_node, bool graph_format_transform) {
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
}
std::string origin_data_format = kOpFormat_DEFAULT;
if (graph_format_transform && IsNeedProcessFormatInfo(kernel_node, inputs_type)) {
if (IsNeedProcessFormatInfo(kernel_node, inputs_type)) {
UpdateKernelFormatInfo(kernel_node, inputs_type, &inputs_format, &outputs_format, &origin_data_format);
}
std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder =

View File

@ -20,11 +20,13 @@
#include <utility>
#include <string>
#include <vector>
#include <memory>
#include <map>
#include "ir/anf.h"
#include "ir/dtype.h"
#include "utils/utils.h"
#include "frontend/operator/ops.h"
#include "backend/session/kernel_graph.h"
namespace mindspore {
namespace device {
@ -53,7 +55,28 @@ static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>
{prim::kPrimAddN->name(), {{}, {0}}},
};
void SetKernelInfo(const CNodePtr &kernel_node, bool graph_format_transform = false);
void SetKernelInfo(const CNodePtr &kernel_node);
class FormatTransformChecker {
public:
void CheckSupportFormatTransform(const std::shared_ptr<session::KernelGraph> &kernel_graph);
bool format_transform() const { return format_transform_; }
static FormatTransformChecker &GetInstance() {
static FormatTransformChecker instance;
return instance;
}
private:
FormatTransformChecker() = default;
~FormatTransformChecker() = default;
FormatTransformChecker(const FormatTransformChecker &);
FormatTransformChecker &operator=(const FormatTransformChecker &);
bool format_transform_{true};
static constexpr size_t kConv2dCount = 96;
static constexpr size_t kFusedBatchNormCount = 94;
};
class KernelAttr {
public:

View File

@ -133,6 +133,10 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/ps/scheduler.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_fusion.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_grad_fusion.cc")
add_library(_ut_mindspore_obj OBJECT ${MINDSPORE_SRC_LIST})
add_library(_ut_ut_obj OBJECT ${UT_SRCS})