forked from mindspore-Ecosystem/mindspore
!5667 add kernel select after optimize pass
Merge pull request !5667 from zyli2020/code_refactor
This commit is contained in:
commit
bc4c5afc1a
|
@ -96,9 +96,15 @@ class ActivationGradGpuKernel : public GpuKernel {
|
||||||
const int split_dim = 4;
|
const int split_dim = 4;
|
||||||
if (input_shape.size() <= split_dim) {
|
if (input_shape.size() <= split_dim) {
|
||||||
ShapeNdTo4d(input_shape, &shape);
|
ShapeNdTo4d(input_shape, &shape);
|
||||||
|
if (AnfAlgo::GetInputFormat(kernel_node, 0) == kOpFormat_NHWC) {
|
||||||
|
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NHWC, cudnn_data_type_,
|
||||||
|
shape[0], shape[3], shape[1], shape[2]),
|
||||||
|
"cudnnSetTensor4dDescriptor failed");
|
||||||
|
} else {
|
||||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
|
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
|
||||||
shape[0], shape[1], shape[2], shape[3]),
|
shape[0], shape[1], shape[2], shape[3]),
|
||||||
"SetTensor4dDescriptor failed");
|
"cudnnSetTensor4dDescriptor failed");
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
CudnnSetTensorNdDescriptor(input_shape, data_descriptor_, cudnn_data_type_);
|
CudnnSetTensorNdDescriptor(input_shape, data_descriptor_, cudnn_data_type_);
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,7 +2,6 @@ file(GLOB_RECURSE _PREACTIVATE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
"common/*.cc"
|
"common/*.cc"
|
||||||
"mem_reuse/*.cc"
|
"mem_reuse/*.cc"
|
||||||
"pass/*.cc"
|
"pass/*.cc"
|
||||||
"gpu/*.cc"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if (ENABLE_D)
|
if (ENABLE_D)
|
||||||
|
@ -10,5 +9,10 @@ if (ENABLE_D)
|
||||||
list(APPEND _PREACTIVATE_SRC_LIST ${_D_SRC_LIST})
|
list(APPEND _PREACTIVATE_SRC_LIST ${_D_SRC_LIST})
|
||||||
endif ()
|
endif ()
|
||||||
|
|
||||||
|
if (ENABLE_GPU)
|
||||||
|
file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu/*.cc")
|
||||||
|
list(APPEND _PREACTIVATE_SRC_LIST ${_GPU_SRC_LIST})
|
||||||
|
endif ()
|
||||||
|
|
||||||
set_property(SOURCE ${_PREACTIVATE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PRE_ACT)
|
set_property(SOURCE ${_PREACTIVATE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PRE_ACT)
|
||||||
add_library(_mindspore_backend_optimizer_obj OBJECT ${_PREACTIVATE_SRC_LIST})
|
add_library(_mindspore_backend_optimizer_obj OBJECT ${_PREACTIVATE_SRC_LIST})
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#include "ir/primitive.h"
|
#include "ir/primitive.h"
|
||||||
#include "utils/utils.h"
|
#include "utils/utils.h"
|
||||||
#include "backend/optimizer/common/helper.h"
|
#include "backend/optimizer/common/helper.h"
|
||||||
|
#include "runtime/device/gpu/kernel_info_setter.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
|
@ -46,7 +47,7 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons
|
||||||
auto batch_norm_ex = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0);
|
auto batch_norm_ex = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0);
|
||||||
MS_EXCEPTION_IF_NULL(batch_norm_ex);
|
MS_EXCEPTION_IF_NULL(batch_norm_ex);
|
||||||
|
|
||||||
if (AnfAlgo::GetOutputInferDataType(batch_norm_ex, 0) != kNumberTypeFloat16) {
|
if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -83,6 +84,7 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons
|
||||||
auto manager = graph->manager();
|
auto manager = graph->manager();
|
||||||
MS_EXCEPTION_IF_NULL(manager);
|
MS_EXCEPTION_IF_NULL(manager);
|
||||||
manager->Replace(batch_norm_ex, fused_batch_norm_with_add_relu);
|
manager->Replace(batch_norm_ex, fused_batch_norm_with_add_relu);
|
||||||
|
device::gpu::SetKernelInfo(fused_batch_norm_with_add_relu);
|
||||||
return tuple_get_item;
|
return tuple_get_item;
|
||||||
}
|
}
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
|
|
|
@ -24,6 +24,7 @@
|
||||||
#include "ir/primitive.h"
|
#include "ir/primitive.h"
|
||||||
#include "utils/utils.h"
|
#include "utils/utils.h"
|
||||||
#include "backend/optimizer/common/helper.h"
|
#include "backend/optimizer/common/helper.h"
|
||||||
|
#include "runtime/device/gpu/kernel_info_setter.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
|
@ -123,7 +124,8 @@ const AnfNodePtr BatchNormAddReluGradFusion::Process(const FuncGraphPtr &graph,
|
||||||
const EquivPtr &) const {
|
const EquivPtr &) const {
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
if (AnfAlgo::GetOutputInferDataType(node, 0) != kNumberTypeFloat16) {
|
|
||||||
|
if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -169,6 +171,7 @@ const AnfNodePtr BatchNormAddReluGradFusion::Process(const FuncGraphPtr &graph,
|
||||||
AnfAlgo::CopyNodeAttrs(node, fused_batch_norm_add_relu_grad);
|
AnfAlgo::CopyNodeAttrs(node, fused_batch_norm_add_relu_grad);
|
||||||
SetShapeAndType(fused_batch_norm_add_relu_grad, node, relu_grad);
|
SetShapeAndType(fused_batch_norm_add_relu_grad, node, relu_grad);
|
||||||
ReplaceOutput(graph, node, relu_grad, fused_batch_norm_add_relu_grad);
|
ReplaceOutput(graph, node, relu_grad, fused_batch_norm_add_relu_grad);
|
||||||
|
device::gpu::SetKernelInfo(fused_batch_norm_add_relu_grad);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#include "ir/primitive.h"
|
#include "ir/primitive.h"
|
||||||
#include "utils/utils.h"
|
#include "utils/utils.h"
|
||||||
#include "backend/optimizer/common/helper.h"
|
#include "backend/optimizer/common/helper.h"
|
||||||
|
#include "runtime/device/gpu/kernel_info_setter.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
|
@ -43,7 +44,7 @@ const AnfNodePtr BatchNormReluFusion::Process(const FuncGraphPtr &graph, const A
|
||||||
auto batch_norm_ex = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0);
|
auto batch_norm_ex = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0);
|
||||||
MS_EXCEPTION_IF_NULL(batch_norm_ex);
|
MS_EXCEPTION_IF_NULL(batch_norm_ex);
|
||||||
|
|
||||||
if (AnfAlgo::GetOutputInferDataType(batch_norm_ex, 0) != kNumberTypeFloat16) {
|
if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,6 +79,7 @@ const AnfNodePtr BatchNormReluFusion::Process(const FuncGraphPtr &graph, const A
|
||||||
auto manager = graph->manager();
|
auto manager = graph->manager();
|
||||||
MS_EXCEPTION_IF_NULL(manager);
|
MS_EXCEPTION_IF_NULL(manager);
|
||||||
manager->Replace(batch_norm_ex, fused_batch_norm_with_relu);
|
manager->Replace(batch_norm_ex, fused_batch_norm_with_relu);
|
||||||
|
device::gpu::SetKernelInfo(fused_batch_norm_with_relu);
|
||||||
return tuple_get_item;
|
return tuple_get_item;
|
||||||
}
|
}
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#include "ir/primitive.h"
|
#include "ir/primitive.h"
|
||||||
#include "utils/utils.h"
|
#include "utils/utils.h"
|
||||||
#include "backend/optimizer/common/helper.h"
|
#include "backend/optimizer/common/helper.h"
|
||||||
|
#include "runtime/device/gpu/kernel_info_setter.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
|
@ -38,7 +39,7 @@ const AnfNodePtr BatchNormReluGradFusion::Process(const FuncGraphPtr &graph, con
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
|
||||||
if (AnfAlgo::GetOutputInferDataType(node, 0) != kNumberTypeFloat16) {
|
if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -84,6 +85,7 @@ const AnfNodePtr BatchNormReluGradFusion::Process(const FuncGraphPtr &graph, con
|
||||||
}
|
}
|
||||||
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fused_batch_norm_grad_with_relu.get());
|
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fused_batch_norm_grad_with_relu.get());
|
||||||
AnfAlgo::CopyNodeAttrs(node, fused_batch_norm_grad_with_relu);
|
AnfAlgo::CopyNodeAttrs(node, fused_batch_norm_grad_with_relu);
|
||||||
|
device::gpu::SetKernelInfo(fused_batch_norm_grad_with_relu);
|
||||||
return fused_batch_norm_grad_with_relu;
|
return fused_batch_norm_grad_with_relu;
|
||||||
}
|
}
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
|
|
|
@ -53,10 +53,10 @@ using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm;
|
||||||
|
|
||||||
void GPUSession::SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
|
void GPUSession::SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||||
bool graph_format_transform = IsSupportFormatTransform(kernel_graph);
|
device::gpu::FormatTransformChecker::GetInstance().CheckSupportFormatTransform(kernel_graph);
|
||||||
for (const auto &kernel_node : kernel_graph->execution_order()) {
|
for (const auto &kernel_node : kernel_graph->execution_order()) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
device::gpu::SetKernelInfo(kernel_node, graph_format_transform);
|
device::gpu::SetKernelInfo(kernel_node);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -82,12 +82,6 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
||||||
pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>());
|
pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>());
|
||||||
pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
|
pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
|
||||||
pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
|
pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
|
||||||
if (IsSupportFormatTransform(kernel_graph) && context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
|
|
||||||
pm->AddPass(std::make_shared<opt::BatchNormReluFusion>());
|
|
||||||
pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>());
|
|
||||||
pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>());
|
|
||||||
// pm->AddPass(std::make_shared<opt::BatchNormAddReluGradFusion>());
|
|
||||||
}
|
|
||||||
optimizer->AddPassManager(pm);
|
optimizer->AddPassManager(pm);
|
||||||
(void)optimizer->Optimize(kernel_graph);
|
(void)optimizer->Optimize(kernel_graph);
|
||||||
kernel_graph->SetExecOrderByDefault();
|
kernel_graph->SetExecOrderByDefault();
|
||||||
|
@ -96,6 +90,10 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
||||||
void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
||||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||||
auto pm = std::make_shared<opt::PassManager>();
|
auto pm = std::make_shared<opt::PassManager>();
|
||||||
|
pm->AddPass(std::make_shared<opt::BatchNormReluFusion>());
|
||||||
|
pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>());
|
||||||
|
pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>());
|
||||||
|
// pm->AddPass(std::make_shared<opt::BatchNormAddReluGradFusion>());
|
||||||
pm->AddPass(std::make_shared<opt::InsertFormatTransformOp>());
|
pm->AddPass(std::make_shared<opt::InsertFormatTransformOp>());
|
||||||
pm->AddPass(std::make_shared<opt::RemoveFormatTransformPair>());
|
pm->AddPass(std::make_shared<opt::RemoveFormatTransformPair>());
|
||||||
pm->AddPass(std::make_shared<opt::RemoveRedundantFormatTransform>());
|
pm->AddPass(std::make_shared<opt::RemoveRedundantFormatTransform>());
|
||||||
|
@ -201,28 +199,6 @@ void GPUSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool GPUSession::IsSupportFormatTransform(const std::shared_ptr<KernelGraph> &kernel_graph) const {
|
|
||||||
auto kernels = kernel_graph->execution_order();
|
|
||||||
size_t conv_cnt = 0;
|
|
||||||
size_t bn_cnt = 0;
|
|
||||||
for (const auto &kernel : kernels) {
|
|
||||||
auto kernel_name = AnfAlgo::GetCNodeName(kernel);
|
|
||||||
if (kernel_name == prim::kPrimLayerNorm->name()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (kernel_name == prim::kPrimConv2D->name()) {
|
|
||||||
conv_cnt++;
|
|
||||||
}
|
|
||||||
if (kernel_name == prim::kPrimFusedBatchNormEx->name()) {
|
|
||||||
bn_cnt++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (conv_cnt == kConv2dCount && bn_cnt == kFusedBatchNormCount) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
|
GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
|
||||||
// Construct graph, if successfully, graph_sum_ + 1
|
// Construct graph, if successfully, graph_sum_ + 1
|
||||||
auto graph_id = graph_sum_;
|
auto graph_id = graph_sum_;
|
||||||
|
@ -232,26 +208,27 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
|
||||||
auto context_ptr = MsContext::GetInstance();
|
auto context_ptr = MsContext::GetInstance();
|
||||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||||
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
|
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
|
||||||
// Optimize
|
// Dump .pb graph before graph optimization
|
||||||
|
if (save_graphs) {
|
||||||
|
DumpIRProto(graph, "before_opt_" + std::to_string(graph_id));
|
||||||
|
}
|
||||||
|
// Graph optimization irrelevant to device data format
|
||||||
Optimize(graph);
|
Optimize(graph);
|
||||||
// Select kernel build info
|
// Select kernel build info
|
||||||
SelectKernel(graph);
|
SelectKernel(graph);
|
||||||
|
// Graph optimization relevant to device data format
|
||||||
|
HardwareOptimize(graph);
|
||||||
|
// Dump .pb graph after graph optimization
|
||||||
|
if (save_graphs) {
|
||||||
|
DumpIRProto(graph, "after_opt_" + std::to_string(graph_id));
|
||||||
|
}
|
||||||
|
|
||||||
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
||||||
// Assign parameter keys.
|
// Assign parameter keys.
|
||||||
AssignParamKey(graph);
|
AssignParamKey(graph);
|
||||||
#endif
|
#endif
|
||||||
// Start gpu kernel runtime
|
// Start gpu kernel runtime
|
||||||
StartKernelRT();
|
StartKernelRT();
|
||||||
// Dump .pb graph before hardware optimization
|
|
||||||
if (save_graphs) {
|
|
||||||
DumpIRProto(graph, "before_hwopt_" + std::to_string(graph_id));
|
|
||||||
}
|
|
||||||
// HardwareOptimize
|
|
||||||
HardwareOptimize(graph);
|
|
||||||
// Dump .pb graph after hardware optimization
|
|
||||||
if (save_graphs) {
|
|
||||||
DumpIRProto(graph, "after_hwopt_" + std::to_string(graph_id));
|
|
||||||
}
|
|
||||||
// Assign CUDA streams
|
// Assign CUDA streams
|
||||||
AssignStream(graph);
|
AssignStream(graph);
|
||||||
// Hide NopOp from execution graph
|
// Hide NopOp from execution graph
|
||||||
|
|
|
@ -67,8 +67,6 @@ class GPUSession : public SessionBasic {
|
||||||
|
|
||||||
void Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
void Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||||
|
|
||||||
bool IsSupportFormatTransform(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
|
||||||
|
|
||||||
#ifdef ENABLE_DEBUGGER
|
#ifdef ENABLE_DEBUGGER
|
||||||
void Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
void Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||||
|
|
||||||
|
@ -82,9 +80,6 @@ class GPUSession : public SessionBasic {
|
||||||
|
|
||||||
void PostLoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
void PostLoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
static constexpr size_t kConv2dCount = 96;
|
|
||||||
static constexpr size_t kFusedBatchNormCount = 94;
|
|
||||||
};
|
};
|
||||||
using GPUSessionPtr = std::shared_ptr<GPUSession>;
|
using GPUSessionPtr = std::shared_ptr<GPUSession>;
|
||||||
MS_REG_SESSION(kGPUDevice, GPUSession);
|
MS_REG_SESSION(kGPUDevice, GPUSession);
|
||||||
|
|
|
@ -47,7 +47,7 @@ bool CudaEnvChecker::CheckNvccInPath() {
|
||||||
};
|
};
|
||||||
|
|
||||||
auto cuda_paths = GetCudaRealPaths();
|
auto cuda_paths = GetCudaRealPaths();
|
||||||
find_nvcc_ = any_of(cuda_paths.begin(), cuda_paths.end(), checker);
|
find_nvcc_ = std::any_of(cuda_paths.begin(), cuda_paths.end(), checker);
|
||||||
already_check_nvcc_ = true;
|
already_check_nvcc_ = true;
|
||||||
return find_nvcc_;
|
return find_nvcc_;
|
||||||
}
|
}
|
||||||
|
|
|
@ -165,6 +165,9 @@ bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vector<Type
|
||||||
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
|
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
if (!FormatTransformChecker::GetInstance().format_transform()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
if (!AnfAlgo::IsRealCNodeKernel(kernel_node)) {
|
if (!AnfAlgo::IsRealCNodeKernel(kernel_node)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -232,7 +235,31 @@ void UpdateKernelFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeI
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void SetKernelInfo(const CNodePtr &kernel_node, bool graph_format_transform) {
|
void FormatTransformChecker::CheckSupportFormatTransform(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
||||||
|
auto kernels = kernel_graph->execution_order();
|
||||||
|
size_t conv_cnt = 0;
|
||||||
|
size_t bn_cnt = 0;
|
||||||
|
for (const auto &kernel : kernels) {
|
||||||
|
auto kernel_name = AnfAlgo::GetCNodeName(kernel);
|
||||||
|
if (kernel_name == prim::kPrimLayerNorm->name()) {
|
||||||
|
format_transform_ = false;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (kernel_name == prim::kPrimConv2D->name()) {
|
||||||
|
conv_cnt++;
|
||||||
|
}
|
||||||
|
if (kernel_name == prim::kPrimFusedBatchNormEx->name()) {
|
||||||
|
bn_cnt++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (conv_cnt == kConv2dCount && bn_cnt == kFusedBatchNormCount) {
|
||||||
|
format_transform_ = false;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
format_transform_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetKernelInfo(const CNodePtr &kernel_node) {
|
||||||
std::vector<std::string> inputs_format;
|
std::vector<std::string> inputs_format;
|
||||||
std::vector<TypeId> inputs_type;
|
std::vector<TypeId> inputs_type;
|
||||||
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
|
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
|
||||||
|
@ -246,7 +273,7 @@ void SetKernelInfo(const CNodePtr &kernel_node, bool graph_format_transform) {
|
||||||
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
|
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
|
||||||
}
|
}
|
||||||
std::string origin_data_format = kOpFormat_DEFAULT;
|
std::string origin_data_format = kOpFormat_DEFAULT;
|
||||||
if (graph_format_transform && IsNeedProcessFormatInfo(kernel_node, inputs_type)) {
|
if (IsNeedProcessFormatInfo(kernel_node, inputs_type)) {
|
||||||
UpdateKernelFormatInfo(kernel_node, inputs_type, &inputs_format, &outputs_format, &origin_data_format);
|
UpdateKernelFormatInfo(kernel_node, inputs_type, &inputs_format, &outputs_format, &origin_data_format);
|
||||||
}
|
}
|
||||||
std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder =
|
std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder =
|
||||||
|
|
|
@ -20,11 +20,13 @@
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include "ir/anf.h"
|
#include "ir/anf.h"
|
||||||
#include "ir/dtype.h"
|
#include "ir/dtype.h"
|
||||||
#include "utils/utils.h"
|
#include "utils/utils.h"
|
||||||
#include "frontend/operator/ops.h"
|
#include "frontend/operator/ops.h"
|
||||||
|
#include "backend/session/kernel_graph.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace device {
|
namespace device {
|
||||||
|
@ -53,7 +55,28 @@ static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>
|
||||||
{prim::kPrimAddN->name(), {{}, {0}}},
|
{prim::kPrimAddN->name(), {{}, {0}}},
|
||||||
};
|
};
|
||||||
|
|
||||||
void SetKernelInfo(const CNodePtr &kernel_node, bool graph_format_transform = false);
|
void SetKernelInfo(const CNodePtr &kernel_node);
|
||||||
|
|
||||||
|
class FormatTransformChecker {
|
||||||
|
public:
|
||||||
|
void CheckSupportFormatTransform(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||||
|
bool format_transform() const { return format_transform_; }
|
||||||
|
|
||||||
|
static FormatTransformChecker &GetInstance() {
|
||||||
|
static FormatTransformChecker instance;
|
||||||
|
return instance;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
FormatTransformChecker() = default;
|
||||||
|
~FormatTransformChecker() = default;
|
||||||
|
FormatTransformChecker(const FormatTransformChecker &);
|
||||||
|
FormatTransformChecker &operator=(const FormatTransformChecker &);
|
||||||
|
|
||||||
|
bool format_transform_{true};
|
||||||
|
static constexpr size_t kConv2dCount = 96;
|
||||||
|
static constexpr size_t kFusedBatchNormCount = 94;
|
||||||
|
};
|
||||||
|
|
||||||
class KernelAttr {
|
class KernelAttr {
|
||||||
public:
|
public:
|
||||||
|
|
|
@ -133,6 +133,10 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/
|
||||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/ps/scheduler.cc")
|
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/ps/scheduler.cc")
|
||||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc")
|
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc")
|
||||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc")
|
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc")
|
||||||
|
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc")
|
||||||
|
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc")
|
||||||
|
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_fusion.cc")
|
||||||
|
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_grad_fusion.cc")
|
||||||
|
|
||||||
add_library(_ut_mindspore_obj OBJECT ${MINDSPORE_SRC_LIST})
|
add_library(_ut_mindspore_obj OBJECT ${MINDSPORE_SRC_LIST})
|
||||||
add_library(_ut_ut_obj OBJECT ${UT_SRCS})
|
add_library(_ut_ut_obj OBJECT ${UT_SRCS})
|
||||||
|
|
Loading…
Reference in New Issue