!40596 ub fusion check graph kernel supported ops

Merge pull request !40596 from looop5/ub_fusion_check
This commit is contained in:
i-robot 2022-08-23 02:57:37 +00:00 committed by Gitee
commit ff1b4f124e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
9 changed files with 51 additions and 20 deletions

View File

@ -26,7 +26,7 @@
#include "common/graph_kernel/graph_kernel_flags.h"
#include "common/graph_kernel/core/graph_kernel_utils.h"
namespace mindspore::graphkernel {
std::vector<PrimitivePtr> GraphKernelExpanderWithPy::InitOpList() {
std::vector<PrimitivePtr> GraphKernelExpanderWithPy::GetExpanderOps() {
std::vector<OpWithLevel> expand_ops_with_level = {
{kAllTarget, OpLevel_0, prim::kPrimAddN},
{kAllTarget, OpLevel_0, prim::kPrimAssignAdd},
@ -91,5 +91,9 @@ std::vector<PrimitivePtr> GraphKernelExpanderWithPy::InitOpList() {
return GkUtils::FilterExcludedOps(ops);
}
std::vector<PrimitivePtr> GraphKernelExpanderWithPy::InitOpList() {
return GraphKernelExpanderWithPy::GetExpanderOps();
}
// Build the expander for `node`; the second argument is fixed to false here.
ExpanderPtr GraphKernelExpanderWithPy::InitExpander(const AnfNodePtr &node) {
  return GetExpander(node, false);
}
} // namespace mindspore::graphkernel

View File

@ -29,6 +29,7 @@ class GraphKernelExpanderWithPy : public GraphKernelExpander {
GraphKernelExpanderWithPy() : GraphKernelExpander() {}
explicit GraphKernelExpanderWithPy(const std::string &name) : GraphKernelExpander(name) {}
~GraphKernelExpanderWithPy() override = default;
static std::vector<PrimitivePtr> GetExpanderOps();
protected:
std::vector<PrimitivePtr> InitOpList() override;

View File

@ -273,4 +273,20 @@ void GraphKernelOptimize(const KernelGraphPtr &kernel_graph) {
GraphKernelOptimizer graph_kernel_optimizer;
graph_kernel_optimizer.Run(kernel_graph);
}
bool GraphKernelSupported(const std::vector<AnfNodePtr> &nodes) {
static std::vector<PrimitivePtr> supported_nodes;
if (supported_nodes.empty()) {
supported_nodes = GraphKernelExpanderWithPy::GetExpanderOps();
auto cluster_nodes = GraphKernelCluster::GetClusterOps();
(void)std::copy(cluster_nodes.begin(), cluster_nodes.end(), std::back_inserter(supported_nodes));
}
for (const auto &node : nodes) {
if (node != nullptr && !std::any_of(supported_nodes.begin(), supported_nodes.end(),
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); })) {
return false;
}
}
return true;
}
} // namespace mindspore::graphkernel

View File

@ -16,6 +16,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADAPTER_GRAPH_KERNEL_OPTIMIZATION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADAPTER_GRAPH_KERNEL_OPTIMIZATION_H_
#include <vector>
#include "backend/common/session/kernel_graph.h"
#include "backend/common/optimizer/optimizer.h"
#include "backend/common/optimizer/pass_manager.h"
@ -50,5 +51,6 @@ class GraphKernelOptimizer {
};
void GraphKernelOptimize(const KernelGraphPtr &kernel_graph);
bool GraphKernelSupported(const std::vector<AnfNodePtr> &nodes);
} // namespace mindspore::graphkernel
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADAPTER_GRAPH_KERNEL_OPTIMIZATION_H_

View File

@ -33,7 +33,7 @@
#include "common/graph_kernel/core/graph_builder.h"
namespace mindspore::graphkernel {
std::vector<PrimitivePtr> GraphKernelCluster::GetClusterableOpList() {
std::vector<PrimitivePtr> GraphKernelCluster::GetClusterOps() {
std::vector<OpWithLevel> clusterable_ops_with_level = {
// all target
{kAllTarget, OpLevel_0, prim::kPrimAbs},
@ -113,6 +113,8 @@ std::vector<PrimitivePtr> GraphKernelCluster::GetClusterableOpList() {
return GkUtils::FilterExcludedOps(ops);
}
std::vector<PrimitivePtr> GraphKernelCluster::GetClusterableOpList() { return GraphKernelCluster::GetClusterOps(); }
bool GraphKernelCluster::IsClusterableOp(const AnfNodePtr &node) {
if (AnfUtils::IsGraphKernel(node)) {
auto sub_graph = GetCNodeFuncGraph(node);

View File

@ -33,6 +33,7 @@ class GraphKernelCluster : public opt::Pass {
GraphKernelCluster() : Pass("graph_kernel_cluster") {}
~GraphKernelCluster() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
static std::vector<PrimitivePtr> GetClusterOps();
protected:
virtual std::vector<PrimitivePtr> GetClusterableOpList();

View File

@ -173,7 +173,6 @@
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_compile.h"
#include "utils/ms_context.h"
#include "include/common/utils/config_manager.h"
#include "common/graph_kernel/graph_kernel_flags.h"
#include "include/common/debug/anf_ir_dump.h"
#include "include/common/debug/dump_proto.h"
#include "include/common/debug/draw.h"
@ -611,14 +610,10 @@ void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGrap
ub_fusion_pm->AddPass(std::make_shared<ReduceEltwiseFusionPass>(fusion_id_allocator));
ub_fusion_pm->AddPass(std::make_shared<SegmentEltwiseFusionPass>(fusion_id_allocator));
ub_fusion_pm->AddPass(std::make_shared<MultiOutputFusionPass>(fusion_id_allocator));
if (!graphkernel::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
ub_fusion_pm->AddPass(std::make_shared<EltwiseFusionPass>(fusion_id_allocator));
}
ub_fusion_pm->AddPass(std::make_shared<EltwiseFusionPass>(fusion_id_allocator));
ub_fusion_pm->AddPass(std::make_shared<DepthwiseConvEltwiseFusionPass>(fusion_id_allocator));
ub_fusion_pm->AddPass(std::make_shared<MatmulConfusionTranposeFusionPass>(fusion_id_allocator));
if (!graphkernel::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
ub_fusion_pm->AddPass(std::make_shared<BatchMatmulEltwiseFusionPass>(fusion_id_allocator));
}
ub_fusion_pm->AddPass(std::make_shared<BatchMatmulEltwiseFusionPass>(fusion_id_allocator));
ub_fusion_pm->AddPass(std::make_shared<BatchMatmulDropoutDoMaskV3FusionPass>(fusion_id_allocator));
ub_fusion_pm->AddPass(std::make_shared<UbPatternFusion>());
optimizer->AddPassManager(ub_fusion_pm);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -18,7 +18,6 @@
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "mindspore/core/ops/core_ops.h"
#include "common/graph_kernel/graph_kernel_flags.h"
#include "backend/common/optimizer/fusion_id_allocator.h"
namespace mindspore {
@ -46,14 +45,6 @@ void MatmulEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGrap
continue;
}
auto cnode = node->cast<CNodePtr>();
if (graphkernel::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE &&
common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimAddN)) {
continue;
}
}
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE &&

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -27,6 +27,10 @@
#include "plugin/device/ascend/kernel/tbe/tbe_utils.h"
#include "include/common/debug/anf_ir_dump.h"
#include "backend/common/optimizer/helper.h"
#ifdef ENABLE_AKG
#include "common/graph_kernel/graph_kernel_flags.h"
#include "common/graph_kernel/adapter/graph_kernel_optimization.h"
#endif
namespace mindspore {
namespace opt {
@ -218,6 +222,21 @@ void GetFusionScopeComputeNodeList(const session::KernelGraph *kernel_graph,
(*buffer_fusion_infos)[fusion_id].anf_nodes.push_back(cnode);
}
}
#ifdef ENABLE_AKG
// If Graph Kernel fusion is enabled, let Graph Kernel fuse these nodes instead, provided it supports all of them.
if (graphkernel::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
auto iter = buffer_fusion_infos->begin();
while (iter != buffer_fusion_infos->end()) {
if (graphkernel::GraphKernelSupported(iter->second.anf_nodes)) {
MS_LOG(DEBUG) << "Fusion id: " << iter->first << ", uses Graph Kernel Fusion";
buffer_fusion_infos->erase(iter++);
} else {
iter++;
}
}
}
#endif
}
void GetFusionScopeInputNodeList(const session::KernelGraph &kernel_graph,