!40596 ub fusion check graph kernel supported ops
Merge pull request !40596 from looop5/ub_fusion_check
commit ff1b4f124e
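In short: the Ascend UB (buffer) fusion pass no longer skips EltwiseFusionPass and BatchMatmulEltwiseFusionPass wholesale when graph kernel is enabled; instead it drops only those individual fusion groups whose nodes graph kernel can fuse itself. To make that check possible, the expander and cluster op lists are exposed through static accessors and combined in a new graphkernel::GraphKernelSupported helper. A conceptual sketch of the resulting decision, assuming fusion_nodes is the node list of one buffer-fusion group (the wrapper name LeaveToGraphKernel is hypothetical, not part of the patch):

```cpp
#include <vector>

#include "common/graph_kernel/adapter/graph_kernel_optimization.h"
#include "common/graph_kernel/graph_kernel_flags.h"

// Hypothetical wrapper illustrating the check this patch adds in
// GetFusionScopeComputeNodeList: a buffer-fusion group is left to graph kernel
// when the feature is enabled and every node in the group is a supported op.
bool LeaveToGraphKernel(const std::vector<mindspore::AnfNodePtr> &fusion_nodes) {
  return mindspore::graphkernel::GraphKernelFlags::GetInstance().IsEnableGraphKernel() &&
         mindspore::graphkernel::GraphKernelSupported(fusion_nodes);
}
```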
@@ -26,7 +26,7 @@
 #include "common/graph_kernel/graph_kernel_flags.h"
 #include "common/graph_kernel/core/graph_kernel_utils.h"
 namespace mindspore::graphkernel {
-std::vector<PrimitivePtr> GraphKernelExpanderWithPy::InitOpList() {
+std::vector<PrimitivePtr> GraphKernelExpanderWithPy::GetExpanderOps() {
   std::vector<OpWithLevel> expand_ops_with_level = {
     {kAllTarget, OpLevel_0, prim::kPrimAddN},
     {kAllTarget, OpLevel_0, prim::kPrimAssignAdd},
@@ -91,5 +91,9 @@ std::vector<PrimitivePtr> GraphKernelExpanderWithPy::InitOpList() {
   return GkUtils::FilterExcludedOps(ops);
 }
+
+std::vector<PrimitivePtr> GraphKernelExpanderWithPy::InitOpList() {
+  return GraphKernelExpanderWithPy::GetExpanderOps();
+}
 
 ExpanderPtr GraphKernelExpanderWithPy::InitExpander(const AnfNodePtr &node) { return GetExpander(node, false); }
 }  // namespace mindspore::graphkernel
@@ -29,6 +29,7 @@ class GraphKernelExpanderWithPy : public GraphKernelExpander {
   GraphKernelExpanderWithPy() : GraphKernelExpander() {}
   explicit GraphKernelExpanderWithPy(const std::string &name) : GraphKernelExpander(name) {}
   ~GraphKernelExpanderWithPy() override = default;
+  static std::vector<PrimitivePtr> GetExpanderOps();
 
  protected:
   std::vector<PrimitivePtr> InitOpList() override;
@@ -273,4 +273,20 @@ void GraphKernelOptimize(const KernelGraphPtr &kernel_graph) {
   GraphKernelOptimizer graph_kernel_optimizer;
   graph_kernel_optimizer.Run(kernel_graph);
 }
+
+bool GraphKernelSupported(const std::vector<AnfNodePtr> &nodes) {
+  static std::vector<PrimitivePtr> supported_nodes;
+  if (supported_nodes.empty()) {
+    supported_nodes = GraphKernelExpanderWithPy::GetExpanderOps();
+    auto cluster_nodes = GraphKernelCluster::GetClusterOps();
+    (void)std::copy(cluster_nodes.begin(), cluster_nodes.end(), std::back_inserter(supported_nodes));
+  }
+  for (const auto &node : nodes) {
+    if (node != nullptr && !std::any_of(supported_nodes.begin(), supported_nodes.end(),
+                                        [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); })) {
+      return false;
+    }
+  }
+  return true;
+}
 }  // namespace mindspore::graphkernel
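Note on the new helper above: the supported-op list is cached in a function-local static and built only on first use, from the expander ops plus the cluster ops. A minimal stand-alone sketch of the same caching idiom (hypothetical names, not MindSpore API):

```cpp
#include <string>
#include <vector>

// Build-once cache of supported op names, mirroring how GraphKernelSupported
// fills its static supported_nodes vector on the first call and reuses it later.
const std::vector<std::string> &SupportedOpNames() {
  static std::vector<std::string> names;  // initialized once; thread-safe since C++11
  if (names.empty()) {
    names = {"Abs", "AddN", "AssignAdd"};  // stand-ins for the expander + cluster op lists
  }
  return names;
}
```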
@@ -16,6 +16,7 @@
 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADAPTER_GRAPH_KERNEL_OPTIMIZATION_H_
 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADAPTER_GRAPH_KERNEL_OPTIMIZATION_H_
 
+#include <vector>
 #include "backend/common/session/kernel_graph.h"
 #include "backend/common/optimizer/optimizer.h"
 #include "backend/common/optimizer/pass_manager.h"
@@ -50,5 +51,6 @@ class GraphKernelOptimizer {
 };
 
 void GraphKernelOptimize(const KernelGraphPtr &kernel_graph);
+bool GraphKernelSupported(const std::vector<AnfNodePtr> &nodes);
 }  // namespace mindspore::graphkernel
 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADAPTER_GRAPH_KERNEL_OPTIMIZATION_H_
@@ -33,7 +33,7 @@
 #include "common/graph_kernel/core/graph_builder.h"
 
 namespace mindspore::graphkernel {
-std::vector<PrimitivePtr> GraphKernelCluster::GetClusterableOpList() {
+std::vector<PrimitivePtr> GraphKernelCluster::GetClusterOps() {
   std::vector<OpWithLevel> clusterable_ops_with_level = {
     // all target
     {kAllTarget, OpLevel_0, prim::kPrimAbs},
@@ -113,6 +113,8 @@ std::vector<PrimitivePtr> GraphKernelCluster::GetClusterableOpList() {
   return GkUtils::FilterExcludedOps(ops);
 }
+
+std::vector<PrimitivePtr> GraphKernelCluster::GetClusterableOpList() { return GraphKernelCluster::GetClusterOps(); }
 
 bool GraphKernelCluster::IsClusterableOp(const AnfNodePtr &node) {
   if (AnfUtils::IsGraphKernel(node)) {
     auto sub_graph = GetCNodeFuncGraph(node);
@@ -33,6 +33,7 @@ class GraphKernelCluster : public opt::Pass {
   GraphKernelCluster() : Pass("graph_kernel_cluster") {}
   ~GraphKernelCluster() override = default;
   bool Run(const FuncGraphPtr &func_graph) override;
+  static std::vector<PrimitivePtr> GetClusterOps();
 
  protected:
   virtual std::vector<PrimitivePtr> GetClusterableOpList();
@@ -173,7 +173,6 @@
 #include "plugin/device/ascend/kernel/tbe/tbe_kernel_compile.h"
 #include "utils/ms_context.h"
 #include "include/common/utils/config_manager.h"
-#include "common/graph_kernel/graph_kernel_flags.h"
 #include "include/common/debug/anf_ir_dump.h"
 #include "include/common/debug/dump_proto.h"
 #include "include/common/debug/draw.h"
@@ -611,14 +610,10 @@ void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGrap
   ub_fusion_pm->AddPass(std::make_shared<ReduceEltwiseFusionPass>(fusion_id_allocator));
   ub_fusion_pm->AddPass(std::make_shared<SegmentEltwiseFusionPass>(fusion_id_allocator));
   ub_fusion_pm->AddPass(std::make_shared<MultiOutputFusionPass>(fusion_id_allocator));
-  if (!graphkernel::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
-    ub_fusion_pm->AddPass(std::make_shared<EltwiseFusionPass>(fusion_id_allocator));
-  }
+  ub_fusion_pm->AddPass(std::make_shared<EltwiseFusionPass>(fusion_id_allocator));
   ub_fusion_pm->AddPass(std::make_shared<DepthwiseConvEltwiseFusionPass>(fusion_id_allocator));
   ub_fusion_pm->AddPass(std::make_shared<MatmulConfusionTranposeFusionPass>(fusion_id_allocator));
-  if (!graphkernel::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
-    ub_fusion_pm->AddPass(std::make_shared<BatchMatmulEltwiseFusionPass>(fusion_id_allocator));
-  }
+  ub_fusion_pm->AddPass(std::make_shared<BatchMatmulEltwiseFusionPass>(fusion_id_allocator));
   ub_fusion_pm->AddPass(std::make_shared<BatchMatmulDropoutDoMaskV3FusionPass>(fusion_id_allocator));
   ub_fusion_pm->AddPass(std::make_shared<UbPatternFusion>());
   optimizer->AddPassManager(ub_fusion_pm);
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ * Copyright 2020-2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,7 +18,6 @@
 #include "backend/common/session/anf_runtime_algorithm.h"
 #include "include/common/utils/anfalgo.h"
 #include "mindspore/core/ops/core_ops.h"
-#include "common/graph_kernel/graph_kernel_flags.h"
 #include "backend/common/optimizer/fusion_id_allocator.h"
 
 namespace mindspore {
@@ -46,14 +45,6 @@ void MatmulEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGrap
       continue;
     }
     auto cnode = node->cast<CNodePtr>();
-    if (graphkernel::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
-      if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
-          AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE &&
-          common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimAddN)) {
-        continue;
-      }
-    }
-
     MS_EXCEPTION_IF_NULL(cnode);
     if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
         AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE &&
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2021 Huawei Technologies Co., Ltd
+ * Copyright 2019-2022 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -27,6 +27,10 @@
 #include "plugin/device/ascend/kernel/tbe/tbe_utils.h"
 #include "include/common/debug/anf_ir_dump.h"
 #include "backend/common/optimizer/helper.h"
+#ifdef ENABLE_AKG
+#include "common/graph_kernel/graph_kernel_flags.h"
+#include "common/graph_kernel/adapter/graph_kernel_optimization.h"
+#endif
 
 namespace mindspore {
 namespace opt {
@@ -218,6 +222,21 @@ void GetFusionScopeComputeNodeList(const session::KernelGraph *kernel_graph,
       (*buffer_fusion_infos)[fusion_id].anf_nodes.push_back(cnode);
     }
   }
+
+#ifdef ENABLE_AKG
+  // If Graph Kernel Fusion is enabled, we will let Graph Kernel fuse these nodes if it supports.
+  if (graphkernel::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
+    auto iter = buffer_fusion_infos->begin();
+    while (iter != buffer_fusion_infos->end()) {
+      if (graphkernel::GraphKernelSupported(iter->second.anf_nodes)) {
+        MS_LOG(DEBUG) << "Fusion id: " << iter->first << ", uses Graph Kernel Fusion";
+        buffer_fusion_infos->erase(iter++);
+      } else {
+        iter++;
+      }
+    }
+  }
+#endif
 }
 
 void GetFusionScopeInputNodeList(const session::KernelGraph &kernel_graph,
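The while loop added above removes map entries during iteration with buffer_fusion_infos->erase(iter++): the iterator is advanced before the old element is handed to erase, so it is never used after invalidation. A self-contained sketch of the same pattern on a plain std::unordered_map (hypothetical data, not MindSpore code):

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // Stand-in for buffer_fusion_infos: fusion id -> node names of one fusion group.
  std::unordered_map<int64_t, std::vector<std::string>> fusion_infos = {
      {1, {"Add", "Mul"}}, {2, {"MatMul"}}, {3, {"Abs", "Exp"}}};

  // Placeholder predicate standing in for graphkernel::GraphKernelSupported.
  auto is_supported = [](const std::vector<std::string> &nodes) { return nodes.size() > 1; };

  // Erase-while-iterating: iter++ yields the old position for erase and
  // advances iter first, exactly like buffer_fusion_infos->erase(iter++).
  for (auto iter = fusion_infos.begin(); iter != fusion_infos.end();) {
    if (is_supported(iter->second)) {
      fusion_infos.erase(iter++);
    } else {
      ++iter;
    }
  }

  for (const auto &kv : fusion_infos) {
    std::cout << "kept fusion id " << kv.first << "\n";  // only group 2 remains
  }
  return 0;
}
```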