forked from mindspore-Ecosystem/mindspore
!5481 add mode black list checker
Merge pull request !5481 from zyli2020/master
This commit is contained in:
commit
18253952f5
|
@ -49,9 +49,10 @@ using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm;
|
|||
|
||||
void GPUSession::SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
bool in_black_list = CheckInModeBlackList(kernel_graph);
|
||||
for (const auto &kernel_node : kernel_graph->execution_order()) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
device::gpu::SetKernelInfo(kernel_node);
|
||||
device::gpu::SetKernelInfo(kernel_node, in_black_list);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -75,7 +76,7 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
|||
pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
|
||||
if (context_ptr->execution_mode() != kPynativeMode) {
|
||||
if (!CheckInModeBlackList(kernel_graph) && context_ptr->execution_mode() != kPynativeMode) {
|
||||
pm->AddPass(std::make_shared<opt::BatchNormReluFusion>());
|
||||
pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>());
|
||||
pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>());
|
||||
|
@ -192,6 +193,28 @@ void GPUSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const
|
|||
}
|
||||
}
|
||||
|
||||
bool GPUSession::CheckInModeBlackList(const std::shared_ptr<KernelGraph> &kernel_graph) const {
|
||||
auto kernels = kernel_graph->execution_order();
|
||||
size_t conv_cnt = 0;
|
||||
size_t bn_cnt = 0;
|
||||
for (const auto &kernel : kernels) {
|
||||
auto kernel_name = AnfAlgo::GetCNodeName(kernel);
|
||||
if (kernel_name == prim::kPrimLayerNorm->name()) {
|
||||
return true;
|
||||
}
|
||||
if (kernel_name == prim::kPrimConv2D->name()) {
|
||||
conv_cnt++;
|
||||
}
|
||||
if (kernel_name == prim::kPrimFusedBatchNormEx->name()) {
|
||||
bn_cnt++;
|
||||
}
|
||||
}
|
||||
if (conv_cnt == kConv2dCount && bn_cnt == kFusedBatchNormCount) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
|
||||
// Construct graph, if successfully, graph_sum_ + 1
|
||||
auto graph_id = graph_sum_;
|
||||
|
|
|
@ -67,6 +67,8 @@ class GPUSession : public SessionBasic {
|
|||
|
||||
void Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
|
||||
bool CheckInModeBlackList(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
|
||||
#ifdef ENABLE_DEBUGGER
|
||||
void Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
|
||||
|
@ -80,6 +82,9 @@ class GPUSession : public SessionBasic {
|
|||
|
||||
void PostLoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
#endif
|
||||
|
||||
// Exact number of Conv2D kernels a graph must contain (together with a matching
// FusedBatchNormEx count) for CheckInModeBlackList to report it as black-listed.
// NOTE(review): 96 presumably fingerprints one specific network — confirm which and document it.
static constexpr size_t kConv2dCount = 96;
|
||||
// Exact number of FusedBatchNormEx kernels a graph must contain (together with a
// matching Conv2D count) for CheckInModeBlackList to report it as black-listed.
// NOTE(review): 94 presumably fingerprints one specific network — confirm which and document it.
static constexpr size_t kFusedBatchNormCount = 94;
|
||||
};
|
||||
using GPUSessionPtr = std::shared_ptr<GPUSession>;
|
||||
MS_REG_SESSION(kGPUDevice, GPUSession);
|
||||
|
|
|
@ -223,7 +223,7 @@ void UpdateKernelFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeI
|
|||
}
|
||||
} // namespace
|
||||
|
||||
void SetKernelInfo(const CNodePtr &kernel_node) {
|
||||
void SetKernelInfo(const CNodePtr &kernel_node, bool in_black_list) {
|
||||
std::vector<std::string> inputs_format;
|
||||
std::vector<TypeId> inputs_type;
|
||||
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
|
||||
|
@ -237,7 +237,7 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
|
|||
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
|
||||
}
|
||||
std::string origin_data_format = kOpFormat_DEFAULT;
|
||||
if (IsNeedProcessFormatInfo(kernel_node, inputs_type)) {
|
||||
if (!in_black_list && IsNeedProcessFormatInfo(kernel_node, inputs_type)) {
|
||||
UpdateKernelFormatInfo(kernel_node, inputs_type, &inputs_format, &outputs_format, &origin_data_format);
|
||||
}
|
||||
std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder =
|
||||
|
|
|
@ -53,7 +53,7 @@ static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>
|
|||
{prim::kPrimAddN->name(), {{}, {0}}},
|
||||
};
|
||||
|
||||
void SetKernelInfo(const CNodePtr &apply_kernel_ptr);
|
||||
void SetKernelInfo(const CNodePtr &kernel_node, bool in_black_list = false);
|
||||
|
||||
class KernelAttr {
|
||||
public:
|
||||
|
|
Loading…
Reference in New Issue