forked from mindspore-Ecosystem/mindspore
!5667 add kernel select after optimize pass
Merge pull request !5667 from zyli2020/code_refactor
This commit is contained in:
commit
bc4c5afc1a
|
@ -96,9 +96,15 @@ class ActivationGradGpuKernel : public GpuKernel {
|
|||
const int split_dim = 4;
|
||||
if (input_shape.size() <= split_dim) {
|
||||
ShapeNdTo4d(input_shape, &shape);
|
||||
if (AnfAlgo::GetInputFormat(kernel_node, 0) == kOpFormat_NHWC) {
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NHWC, cudnn_data_type_,
|
||||
shape[0], shape[3], shape[1], shape[2]),
|
||||
"cudnnSetTensor4dDescriptor failed");
|
||||
} else {
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
|
||||
shape[0], shape[1], shape[2], shape[3]),
|
||||
"SetTensor4dDescriptor failed");
|
||||
"cudnnSetTensor4dDescriptor failed");
|
||||
}
|
||||
} else {
|
||||
CudnnSetTensorNdDescriptor(input_shape, data_descriptor_, cudnn_data_type_);
|
||||
}
|
||||
|
|
|
@ -2,7 +2,6 @@ file(GLOB_RECURSE _PREACTIVATE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
"common/*.cc"
|
||||
"mem_reuse/*.cc"
|
||||
"pass/*.cc"
|
||||
"gpu/*.cc"
|
||||
)
|
||||
|
||||
if (ENABLE_D)
|
||||
|
@ -10,5 +9,10 @@ if (ENABLE_D)
|
|||
list(APPEND _PREACTIVATE_SRC_LIST ${_D_SRC_LIST})
|
||||
endif ()
|
||||
|
||||
if (ENABLE_GPU)
|
||||
file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu/*.cc")
|
||||
list(APPEND _PREACTIVATE_SRC_LIST ${_GPU_SRC_LIST})
|
||||
endif ()
|
||||
|
||||
set_property(SOURCE ${_PREACTIVATE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PRE_ACT)
|
||||
add_library(_mindspore_backend_optimizer_obj OBJECT ${_PREACTIVATE_SRC_LIST})
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "ir/primitive.h"
|
||||
#include "utils/utils.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "runtime/device/gpu/kernel_info_setter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -46,7 +47,7 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons
|
|||
auto batch_norm_ex = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0);
|
||||
MS_EXCEPTION_IF_NULL(batch_norm_ex);
|
||||
|
||||
if (AnfAlgo::GetOutputInferDataType(batch_norm_ex, 0) != kNumberTypeFloat16) {
|
||||
if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -83,6 +84,7 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons
|
|||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
manager->Replace(batch_norm_ex, fused_batch_norm_with_add_relu);
|
||||
device::gpu::SetKernelInfo(fused_batch_norm_with_add_relu);
|
||||
return tuple_get_item;
|
||||
}
|
||||
} // namespace opt
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "ir/primitive.h"
|
||||
#include "utils/utils.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "runtime/device/gpu/kernel_info_setter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -123,7 +124,8 @@ const AnfNodePtr BatchNormAddReluGradFusion::Process(const FuncGraphPtr &graph,
|
|||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (AnfAlgo::GetOutputInferDataType(node, 0) != kNumberTypeFloat16) {
|
||||
|
||||
if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -169,6 +171,7 @@ const AnfNodePtr BatchNormAddReluGradFusion::Process(const FuncGraphPtr &graph,
|
|||
AnfAlgo::CopyNodeAttrs(node, fused_batch_norm_add_relu_grad);
|
||||
SetShapeAndType(fused_batch_norm_add_relu_grad, node, relu_grad);
|
||||
ReplaceOutput(graph, node, relu_grad, fused_batch_norm_add_relu_grad);
|
||||
device::gpu::SetKernelInfo(fused_batch_norm_add_relu_grad);
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace opt
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "ir/primitive.h"
|
||||
#include "utils/utils.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "runtime/device/gpu/kernel_info_setter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -43,7 +44,7 @@ const AnfNodePtr BatchNormReluFusion::Process(const FuncGraphPtr &graph, const A
|
|||
auto batch_norm_ex = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0);
|
||||
MS_EXCEPTION_IF_NULL(batch_norm_ex);
|
||||
|
||||
if (AnfAlgo::GetOutputInferDataType(batch_norm_ex, 0) != kNumberTypeFloat16) {
|
||||
if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -78,6 +79,7 @@ const AnfNodePtr BatchNormReluFusion::Process(const FuncGraphPtr &graph, const A
|
|||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
manager->Replace(batch_norm_ex, fused_batch_norm_with_relu);
|
||||
device::gpu::SetKernelInfo(fused_batch_norm_with_relu);
|
||||
return tuple_get_item;
|
||||
}
|
||||
} // namespace opt
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "ir/primitive.h"
|
||||
#include "utils/utils.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "runtime/device/gpu/kernel_info_setter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -38,7 +39,7 @@ const AnfNodePtr BatchNormReluGradFusion::Process(const FuncGraphPtr &graph, con
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
if (AnfAlgo::GetOutputInferDataType(node, 0) != kNumberTypeFloat16) {
|
||||
if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -84,6 +85,7 @@ const AnfNodePtr BatchNormReluGradFusion::Process(const FuncGraphPtr &graph, con
|
|||
}
|
||||
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fused_batch_norm_grad_with_relu.get());
|
||||
AnfAlgo::CopyNodeAttrs(node, fused_batch_norm_grad_with_relu);
|
||||
device::gpu::SetKernelInfo(fused_batch_norm_grad_with_relu);
|
||||
return fused_batch_norm_grad_with_relu;
|
||||
}
|
||||
} // namespace opt
|
||||
|
|
|
@ -53,10 +53,10 @@ using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm;
|
|||
|
||||
void GPUSession::SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
bool graph_format_transform = IsSupportFormatTransform(kernel_graph);
|
||||
device::gpu::FormatTransformChecker::GetInstance().CheckSupportFormatTransform(kernel_graph);
|
||||
for (const auto &kernel_node : kernel_graph->execution_order()) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
device::gpu::SetKernelInfo(kernel_node, graph_format_transform);
|
||||
device::gpu::SetKernelInfo(kernel_node);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -82,12 +82,6 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
|||
pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
|
||||
if (IsSupportFormatTransform(kernel_graph) && context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
|
||||
pm->AddPass(std::make_shared<opt::BatchNormReluFusion>());
|
||||
pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>());
|
||||
pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>());
|
||||
// pm->AddPass(std::make_shared<opt::BatchNormAddReluGradFusion>());
|
||||
}
|
||||
optimizer->AddPassManager(pm);
|
||||
(void)optimizer->Optimize(kernel_graph);
|
||||
kernel_graph->SetExecOrderByDefault();
|
||||
|
@ -96,6 +90,10 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
|||
void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::BatchNormReluFusion>());
|
||||
pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>());
|
||||
pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>());
|
||||
// pm->AddPass(std::make_shared<opt::BatchNormAddReluGradFusion>());
|
||||
pm->AddPass(std::make_shared<opt::InsertFormatTransformOp>());
|
||||
pm->AddPass(std::make_shared<opt::RemoveFormatTransformPair>());
|
||||
pm->AddPass(std::make_shared<opt::RemoveRedundantFormatTransform>());
|
||||
|
@ -201,28 +199,6 @@ void GPUSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const
|
|||
}
|
||||
}
|
||||
|
||||
bool GPUSession::IsSupportFormatTransform(const std::shared_ptr<KernelGraph> &kernel_graph) const {
|
||||
auto kernels = kernel_graph->execution_order();
|
||||
size_t conv_cnt = 0;
|
||||
size_t bn_cnt = 0;
|
||||
for (const auto &kernel : kernels) {
|
||||
auto kernel_name = AnfAlgo::GetCNodeName(kernel);
|
||||
if (kernel_name == prim::kPrimLayerNorm->name()) {
|
||||
return false;
|
||||
}
|
||||
if (kernel_name == prim::kPrimConv2D->name()) {
|
||||
conv_cnt++;
|
||||
}
|
||||
if (kernel_name == prim::kPrimFusedBatchNormEx->name()) {
|
||||
bn_cnt++;
|
||||
}
|
||||
}
|
||||
if (conv_cnt == kConv2dCount && bn_cnt == kFusedBatchNormCount) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
|
||||
// Construct graph, if successfully, graph_sum_ + 1
|
||||
auto graph_id = graph_sum_;
|
||||
|
@ -232,26 +208,27 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
|
|||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
|
||||
// Optimize
|
||||
// Dump .pb graph before graph optimization
|
||||
if (save_graphs) {
|
||||
DumpIRProto(graph, "before_opt_" + std::to_string(graph_id));
|
||||
}
|
||||
// Graph optimization irrelevant to device data format
|
||||
Optimize(graph);
|
||||
// Select kernel build info
|
||||
SelectKernel(graph);
|
||||
// Graph optimization relevant to device data format
|
||||
HardwareOptimize(graph);
|
||||
// Dump .pb graph after graph optimization
|
||||
if (save_graphs) {
|
||||
DumpIRProto(graph, "after_opt_" + std::to_string(graph_id));
|
||||
}
|
||||
|
||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||
// Assign parameter keys.
|
||||
AssignParamKey(graph);
|
||||
#endif
|
||||
// Start gpu kernel runtime
|
||||
StartKernelRT();
|
||||
// Dump .pb graph before hardware optimization
|
||||
if (save_graphs) {
|
||||
DumpIRProto(graph, "before_hwopt_" + std::to_string(graph_id));
|
||||
}
|
||||
// HardwareOptimize
|
||||
HardwareOptimize(graph);
|
||||
// Dump .pb graph after hardware optimization
|
||||
if (save_graphs) {
|
||||
DumpIRProto(graph, "after_hwopt_" + std::to_string(graph_id));
|
||||
}
|
||||
// Assign CUDA streams
|
||||
AssignStream(graph);
|
||||
// Hide NopOp from execution graph
|
||||
|
|
|
@ -67,8 +67,6 @@ class GPUSession : public SessionBasic {
|
|||
|
||||
void Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
|
||||
bool IsSupportFormatTransform(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
|
||||
#ifdef ENABLE_DEBUGGER
|
||||
void Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
|
||||
|
@ -82,9 +80,6 @@ class GPUSession : public SessionBasic {
|
|||
|
||||
void PostLoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
#endif
|
||||
|
||||
static constexpr size_t kConv2dCount = 96;
|
||||
static constexpr size_t kFusedBatchNormCount = 94;
|
||||
};
|
||||
using GPUSessionPtr = std::shared_ptr<GPUSession>;
|
||||
MS_REG_SESSION(kGPUDevice, GPUSession);
|
||||
|
|
|
@ -47,7 +47,7 @@ bool CudaEnvChecker::CheckNvccInPath() {
|
|||
};
|
||||
|
||||
auto cuda_paths = GetCudaRealPaths();
|
||||
find_nvcc_ = any_of(cuda_paths.begin(), cuda_paths.end(), checker);
|
||||
find_nvcc_ = std::any_of(cuda_paths.begin(), cuda_paths.end(), checker);
|
||||
already_check_nvcc_ = true;
|
||||
return find_nvcc_;
|
||||
}
|
||||
|
|
|
@ -165,6 +165,9 @@ bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vector<Type
|
|||
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
|
||||
return false;
|
||||
}
|
||||
if (!FormatTransformChecker::GetInstance().format_transform()) {
|
||||
return false;
|
||||
}
|
||||
if (!AnfAlgo::IsRealCNodeKernel(kernel_node)) {
|
||||
return false;
|
||||
}
|
||||
|
@ -232,7 +235,31 @@ void UpdateKernelFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeI
|
|||
}
|
||||
} // namespace
|
||||
|
||||
void SetKernelInfo(const CNodePtr &kernel_node, bool graph_format_transform) {
|
||||
void FormatTransformChecker::CheckSupportFormatTransform(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
||||
auto kernels = kernel_graph->execution_order();
|
||||
size_t conv_cnt = 0;
|
||||
size_t bn_cnt = 0;
|
||||
for (const auto &kernel : kernels) {
|
||||
auto kernel_name = AnfAlgo::GetCNodeName(kernel);
|
||||
if (kernel_name == prim::kPrimLayerNorm->name()) {
|
||||
format_transform_ = false;
|
||||
return;
|
||||
}
|
||||
if (kernel_name == prim::kPrimConv2D->name()) {
|
||||
conv_cnt++;
|
||||
}
|
||||
if (kernel_name == prim::kPrimFusedBatchNormEx->name()) {
|
||||
bn_cnt++;
|
||||
}
|
||||
}
|
||||
if (conv_cnt == kConv2dCount && bn_cnt == kFusedBatchNormCount) {
|
||||
format_transform_ = false;
|
||||
return;
|
||||
}
|
||||
format_transform_ = true;
|
||||
}
|
||||
|
||||
void SetKernelInfo(const CNodePtr &kernel_node) {
|
||||
std::vector<std::string> inputs_format;
|
||||
std::vector<TypeId> inputs_type;
|
||||
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
|
||||
|
@ -246,7 +273,7 @@ void SetKernelInfo(const CNodePtr &kernel_node, bool graph_format_transform) {
|
|||
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
|
||||
}
|
||||
std::string origin_data_format = kOpFormat_DEFAULT;
|
||||
if (graph_format_transform && IsNeedProcessFormatInfo(kernel_node, inputs_type)) {
|
||||
if (IsNeedProcessFormatInfo(kernel_node, inputs_type)) {
|
||||
UpdateKernelFormatInfo(kernel_node, inputs_type, &inputs_format, &outputs_format, &origin_data_format);
|
||||
}
|
||||
std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder =
|
||||
|
|
|
@ -20,11 +20,13 @@
|
|||
#include <utility>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include "ir/anf.h"
|
||||
#include "ir/dtype.h"
|
||||
#include "utils/utils.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "backend/session/kernel_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
|
@ -53,7 +55,28 @@ static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>
|
|||
{prim::kPrimAddN->name(), {{}, {0}}},
|
||||
};
|
||||
|
||||
void SetKernelInfo(const CNodePtr &kernel_node, bool graph_format_transform = false);
|
||||
void SetKernelInfo(const CNodePtr &kernel_node);
|
||||
|
||||
class FormatTransformChecker {
|
||||
public:
|
||||
void CheckSupportFormatTransform(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
bool format_transform() const { return format_transform_; }
|
||||
|
||||
static FormatTransformChecker &GetInstance() {
|
||||
static FormatTransformChecker instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
private:
|
||||
FormatTransformChecker() = default;
|
||||
~FormatTransformChecker() = default;
|
||||
FormatTransformChecker(const FormatTransformChecker &);
|
||||
FormatTransformChecker &operator=(const FormatTransformChecker &);
|
||||
|
||||
bool format_transform_{true};
|
||||
static constexpr size_t kConv2dCount = 96;
|
||||
static constexpr size_t kFusedBatchNormCount = 94;
|
||||
};
|
||||
|
||||
class KernelAttr {
|
||||
public:
|
||||
|
|
|
@ -133,6 +133,10 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/
|
|||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/ps/scheduler.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_fusion.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_grad_fusion.cc")
|
||||
|
||||
add_library(_ut_mindspore_obj OBJECT ${MINDSPORE_SRC_LIST})
|
||||
add_library(_ut_ut_obj OBJECT ${UT_SRCS})
|
||||
|
|
Loading…
Reference in New Issue