forked from mindspore-Ecosystem/mindspore
add graphkerneloptimize pass
align fuse_ops_fusion align composite_ops_fusion unify ops table Init new_code's kernel_info with orig_node's kernel_info in function NewCNodeWithInfo enable run bert add pass tensor_promotion add macro for bias_add and bias_add_grad in expander pass exclude unused attrs in primitive compare for GraphKernelCSE exclude fusion_type in kernelinfo cmp for cse in graphkernel check processor remove graph kernel pass before select kernel recover run_standalone_pretrain_ascend.sh remove is_before_kernel_select move add_atomic_clean from pass directory to graph_kernel directory update fuse op list in Ascend back-end
This commit is contained in:
parent
c54a4a4494
commit
a51465c78b
|
@ -74,7 +74,6 @@
|
|||
#include "backend/optimizer/ascend/format_type/convert_unsupported_transnode_to_aicpu.h"
|
||||
#include "backend/optimizer/pass/eliminate_redundant_op.h"
|
||||
#include "backend/optimizer/pass/common_subexpression_elimination.h"
|
||||
#include "backend/optimizer/pass/add_atomic_clean.h"
|
||||
#include "backend/optimizer/ascend/format_type/merge_cast_to_op.h"
|
||||
#include "backend/optimizer/ascend/format_type/check_consistency.h"
|
||||
#include "backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.h"
|
||||
|
@ -382,74 +381,6 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
|
|||
}
|
||||
}
|
||||
|
||||
void AscendBackendGraphKernelOpt(const std::shared_ptr<session::KernelGraph> &kernel_graph,
|
||||
bool is_before_kernel_select) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
|
||||
return;
|
||||
}
|
||||
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
|
||||
if (save_graphs) {
|
||||
std::string file_name = "hwopt_d_graph_kernel_opt_before_graph_" + std::to_string(!is_before_kernel_select) + "_" +
|
||||
std::to_string(kernel_graph->graph_id()) + ".ir";
|
||||
DumpIR(file_name, kernel_graph);
|
||||
}
|
||||
|
||||
// Fuse graph kernels with basic ops
|
||||
static_cast<void>(FuseCompositeOps(kernel_graph, is_before_kernel_select));
|
||||
|
||||
if (save_graphs) {
|
||||
std::string file_name = "hwopt_d_graph_kernel_opt_end_graph_" + std::to_string(!is_before_kernel_select) + "_" +
|
||||
std::to_string(kernel_graph->graph_id()) + ".ir";
|
||||
DumpIR(file_name, kernel_graph, true);
|
||||
}
|
||||
}
|
||||
|
||||
void AscendBackendFuseBasicOpt(const std::shared_ptr<session::KernelGraph> &kernel_graph,
|
||||
bool is_before_kernel_select) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
|
||||
return;
|
||||
}
|
||||
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
|
||||
if (save_graphs) {
|
||||
std::string file_name = "hwopt_fuse_basic_opt_before_graph_" + std::to_string(!is_before_kernel_select) + "_" +
|
||||
std::to_string(kernel_graph->graph_id()) + ".ir";
|
||||
DumpIR(file_name, kernel_graph, true);
|
||||
}
|
||||
|
||||
// Fuse basic ops with basic ops
|
||||
static_cast<void>(FuseBasicOps(kernel_graph, is_before_kernel_select));
|
||||
|
||||
if (save_graphs) {
|
||||
std::string file_name = "hwopt_fuse_basic_opt_end_graph_" + std::to_string(!is_before_kernel_select) + "_" +
|
||||
std::to_string(kernel_graph->graph_id()) + ".ir";
|
||||
DumpIR(file_name, kernel_graph, true);
|
||||
}
|
||||
}
|
||||
|
||||
void AscendBackendAddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
|
||||
return;
|
||||
}
|
||||
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
|
||||
if (save_graphs) {
|
||||
std::string file_name = "hwopt_d_add_atomic_clean_before_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
|
||||
DumpIR(file_name, kernel_graph);
|
||||
}
|
||||
|
||||
AddAtomicClean(kernel_graph);
|
||||
|
||||
if (save_graphs) {
|
||||
std::string file_name = "hwopt_d_end_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
|
||||
DumpIR(file_name, kernel_graph, true);
|
||||
}
|
||||
}
|
||||
|
||||
void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
|
|
|
@ -25,11 +25,6 @@ void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph)
|
|||
void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
void AscendGraphKernelCommonProcess(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
void AscendBackendGraphKernelOpt(const std::shared_ptr<session::KernelGraph> &kernel_graph,
|
||||
bool is_before_kernel_select = false);
|
||||
void AscendBackendFuseBasicOpt(const std::shared_ptr<session::KernelGraph> &kernel_graph,
|
||||
bool is_before_kernel_select = false);
|
||||
void AscendBackendAddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
} // namespace opt
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/optimizer/pass/add_atomic_clean.h"
|
||||
#include "backend/optimizer/graph_kernel/add_atomic_clean.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
|
@ -75,7 +75,7 @@ CNodePtr CreateTbeAtomicCleanNode(const std::shared_ptr<session::KernelGraph> &k
|
|||
}
|
||||
} // namespace
|
||||
|
||||
void AddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
||||
bool AddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto mng = kernel_graph->manager();
|
||||
if (mng == nullptr) {
|
||||
|
@ -83,6 +83,7 @@ void AddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
|||
kernel_graph->set_manager(mng);
|
||||
}
|
||||
auto &todos = kernel_graph->execution_order();
|
||||
bool changed = false;
|
||||
for (auto iter = todos.cbegin(); iter != todos.end(); ++iter) {
|
||||
auto node = *iter;
|
||||
if (AnfAlgo::IsGraphKernel(node) && kernel_graph->nodes().contains(node)) {
|
||||
|
@ -112,9 +113,17 @@ void AddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
|||
new_cnode->set_kernel_info(node->kernel_info_ptr());
|
||||
mng->Replace(node, new_cnode);
|
||||
g_output_idx.clear();
|
||||
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool CleanAddAtomic::Run(const FuncGraphPtr &func_graph) {
|
||||
return AddAtomicClean(std::dynamic_pointer_cast<session::KernelGraph>(func_graph));
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -14,16 +14,23 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ADD_ATOMIC_CLEAN_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ADD_ATOMIC_CLEAN_H_
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CLEAN_ADD_ATOMIC_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CLEAN_ADD_ATOMIC_H_
|
||||
|
||||
#include <memory>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "backend/session/kernel_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
void AddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
class CleanAddAtomic : public Pass {
|
||||
public:
|
||||
CleanAddAtomic() : Pass("clean_add_atomic") {}
|
||||
~CleanAddAtomic() override = default;
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
};
|
||||
using CleanAddAtomicPtr = std::shared_ptr<CleanAddAtomic>;
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ADD_ATOMIC_CLEAN_H
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CLEAN_ADD_ATOMIC_H_
|
|
@ -35,24 +35,7 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
bool IsBasicOp(const AnfNodePtr &node, bool is_before_kernel_select) {
|
||||
#if ENABLE_D
|
||||
std::vector<PrimitivePtr> fusible_basic_ops = {prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub,
|
||||
prim::kPrimExpandDims};
|
||||
if (!is_before_kernel_select) {
|
||||
fusible_basic_ops.push_back(prim::kPrimCast);
|
||||
}
|
||||
#elif ENABLE_GPU
|
||||
std::vector<PrimitivePtr> fusible_basic_ops = GetFusibleOpList();
|
||||
#else
|
||||
std::vector<PrimitivePtr> fusible_basic_ops;
|
||||
#endif
|
||||
return std::any_of(fusible_basic_ops.begin(), fusible_basic_ops.end(),
|
||||
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
|
||||
}
|
||||
|
||||
IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const GraphKernelInfo &info,
|
||||
const AnfNodePtr &node) {
|
||||
IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNodePtr &node) {
|
||||
if (cur_node == node) {
|
||||
return FOLLOW;
|
||||
}
|
||||
|
@ -60,16 +43,13 @@ IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const GraphKe
|
|||
return EXCLUDE;
|
||||
}
|
||||
|
||||
bool is_fusable = IsBasicOp(node, info.is_before_kernel_select);
|
||||
bool is_fusable = IsBasicFuseOp(node);
|
||||
return is_fusable ? FOLLOW : EXCLUDE;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode, bool is_before_kernel_select) {
|
||||
GraphKernelInfo info;
|
||||
info.is_before_kernel_select = is_before_kernel_select;
|
||||
|
||||
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode) {
|
||||
// Search fusable nodes according input direction.
|
||||
auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, info, std::placeholders::_1);
|
||||
auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, std::placeholders::_1);
|
||||
auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward);
|
||||
if (used_nodes.size() > 1) {
|
||||
used_nodes = RemoveCircle(used_nodes, false);
|
||||
|
@ -176,7 +156,7 @@ void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, con
|
|||
}
|
||||
|
||||
bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr> &todos,
|
||||
std::unordered_set<AnfNodePtr> *fused_ops, bool is_before_kernel_select) {
|
||||
std::unordered_set<AnfNodePtr> *fused_ops) {
|
||||
bool changed = false;
|
||||
auto mng = kernel_graph->manager();
|
||||
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
|
||||
|
@ -187,12 +167,12 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr
|
|||
if (fused_ops->count(node)) {
|
||||
continue;
|
||||
}
|
||||
bool is_basic_op = IsBasicOp(node, is_before_kernel_select);
|
||||
bool is_basic_op = IsBasicFuseOp(node);
|
||||
if (!is_basic_op || !kernel_graph->nodes().contains(node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto fuse_nodes = FindFuseCNodes(node, is_before_kernel_select);
|
||||
auto fuse_nodes = FindFuseCNodes(node);
|
||||
if (fuse_nodes.size() <= 1) {
|
||||
continue;
|
||||
}
|
||||
|
@ -204,10 +184,8 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr
|
|||
std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(fuse_nodes);
|
||||
RemoveControlDependOut(fg, &outputs, mng);
|
||||
ConvertNonscalarTensorToParameter(fg, &inputs);
|
||||
auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs, is_before_kernel_select);
|
||||
if (!is_before_kernel_select) {
|
||||
SetNewKernelInfo(fuse_new_node, fg, inputs, outputs, AnfAlgo::GetProcessor(fuse_nodes[0]));
|
||||
}
|
||||
auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs);
|
||||
SetNewKernelInfo(fuse_new_node, fg, inputs, outputs, AnfAlgo::GetProcessor(fuse_nodes[0]));
|
||||
|
||||
ReplaceNewFuseCNode(kernel_graph, fuse_new_node, outputs);
|
||||
|
||||
|
@ -224,7 +202,7 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr
|
|||
}
|
||||
} // namespace
|
||||
|
||||
bool FuseBasicOps(const FuncGraphPtr &kernel_graph, bool is_before_kernel_select) {
|
||||
bool FuseBasicOps(const FuncGraphPtr &kernel_graph) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto mng = kernel_graph->manager();
|
||||
if (mng == nullptr) {
|
||||
|
@ -234,9 +212,9 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, bool is_before_kernel_select
|
|||
std::unordered_set<AnfNodePtr> fused_ops;
|
||||
auto todos = TopoSort(kernel_graph->get_return());
|
||||
std::reverse(todos.begin(), todos.end());
|
||||
return FuseBasicOps(kernel_graph, todos, &fused_ops, is_before_kernel_select);
|
||||
return FuseBasicOps(kernel_graph, todos, &fused_ops);
|
||||
}
|
||||
|
||||
bool BasicOpsFusion::Run(const FuncGraphPtr &func_graph) { return FuseBasicOps(func_graph, false); }
|
||||
bool BasicOpsFusion::Run(const FuncGraphPtr &func_graph) { return FuseBasicOps(func_graph); }
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
bool FuseBasicOps(const FuncGraphPtr &kernel_graph, bool is_before_kernel_select);
|
||||
bool FuseBasicOps(const FuncGraphPtr &kernel_graph);
|
||||
|
||||
class BasicOpsFusion : public Pass {
|
||||
public:
|
||||
|
|
|
@ -60,127 +60,23 @@ std::vector<AnfNodePtr> DeepUsersSearch(const std::vector<AnfNodePtr> &roots, co
|
|||
}
|
||||
} // namespace
|
||||
|
||||
bool IsBasicFuseOp(const AnfNodePtr &node, bool is_before_kernel_select) {
|
||||
#if ENABLE_D
|
||||
std::vector<PrimitivePtr> basic_ops = {
|
||||
prim::kPrimAddN, prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub, prim::kPrimMaximum,
|
||||
prim::kPrimMinimum, prim::kPrimNeg, prim::kPrimRealDiv, prim::kPrimPow, prim::kPrimSqrt,
|
||||
prim::kPrimExpandDims, prim::kPrimReciprocal, prim::kPrimLessEqual};
|
||||
if (!is_before_kernel_select) {
|
||||
basic_ops.push_back(prim::kPrimCast);
|
||||
}
|
||||
#elif ENABLE_GPU
|
||||
std::vector<PrimitivePtr> basic_ops = GetFusibleOpList();
|
||||
#else
|
||||
std::vector<PrimitivePtr> basic_ops;
|
||||
#endif
|
||||
return std::any_of(basic_ops.begin(), basic_ops.end(),
|
||||
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
|
||||
}
|
||||
|
||||
bool IsReduceOp(const AnfNodePtr &node) {
|
||||
std::vector<PrimitivePtr> reduce_ops = {prim::kPrimReduceSum, prim::kPrimReduceMean, prim::kPrimReduceMin,
|
||||
prim::kPrimReduceMax, prim::kPrimReduceAll};
|
||||
return std::any_of(reduce_ops.begin(), reduce_ops.end(),
|
||||
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
|
||||
}
|
||||
|
||||
void GetGraphKernelInfo(const FuncGraphPtr &fg, GraphKernelInfo *info) {
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto mng = fg->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(fg, false);
|
||||
fg->set_manager(mng);
|
||||
}
|
||||
const auto &nodes = fg->nodes();
|
||||
info->op_type = ELEWISE;
|
||||
info->cal_step = -1;
|
||||
info->reduce_op_num = 0;
|
||||
for (auto node : nodes) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (cnode == nullptr) {
|
||||
continue;
|
||||
}
|
||||
info->cal_step++;
|
||||
if (IsReduceOp(node)) {
|
||||
info->op_type = REDUCE;
|
||||
info->reduce_op_num++;
|
||||
}
|
||||
}
|
||||
auto fg_flag = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
|
||||
if (fg_flag != nullptr) {
|
||||
auto fg_name = GetValue<std::string>(fg_flag);
|
||||
info->origin_composite_name = fg_name;
|
||||
}
|
||||
}
|
||||
|
||||
bool IsCompositeFuseBasic(const GraphKernelInfo &info, const AnfNodePtr &node) {
|
||||
#if ENABLE_D
|
||||
std::vector<PrimitivePtr> fusable_with_reduce;
|
||||
if (!info.is_before_kernel_select) {
|
||||
fusable_with_reduce.push_back(prim::kPrimCast);
|
||||
}
|
||||
if (info.op_type == REDUCE &&
|
||||
(info.cal_step >= MAX_REDUCE_OP_FUSION_CAL_STEP || info.reduce_op_num >= MAX_REDUCE_OP_FUSION_REDUCE_NUM)) {
|
||||
return std::any_of(fusable_with_reduce.begin(), fusable_with_reduce.end(),
|
||||
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
|
||||
}
|
||||
#endif
|
||||
return IsBasicFuseOp(node, info.is_before_kernel_select);
|
||||
}
|
||||
|
||||
bool IsFuse(const GraphKernelInfo &info, const AnfNodePtr &node) {
|
||||
bool IsFuse(const AnfNodePtr &node) {
|
||||
// composite fuse composite op
|
||||
if (AnfAlgo::IsGraphKernel(node)) {
|
||||
#if ENABLE_D
|
||||
return false;
|
||||
#else
|
||||
return true;
|
||||
#endif
|
||||
}
|
||||
return IsCompositeFuseBasic(info, node);
|
||||
return IsBasicFuseOp(node);
|
||||
}
|
||||
|
||||
void UpdateGraphKernelInfo(GraphKernelInfo *info, const AnfNodePtr &node) {
|
||||
if (IsPrimitiveCNode(node)) {
|
||||
info->cal_step++;
|
||||
if (IsReduceOp(node)) {
|
||||
info->op_type = REDUCE;
|
||||
}
|
||||
info->origin_composite_name += AnfAlgo::GetCNodePrimitive(node)->name() + "_";
|
||||
} else if (AnfAlgo::IsGraphKernel(node)) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto composite_g = GetValueNode<FuncGraphPtr>(cnode->input(0));
|
||||
GraphKernelInfo fuse_info;
|
||||
GetGraphKernelInfo(composite_g, &fuse_info);
|
||||
info->cal_step += fuse_info.cal_step;
|
||||
info->origin_composite_name += fuse_info.origin_composite_name;
|
||||
}
|
||||
}
|
||||
|
||||
IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, GraphKernelInfo *info, const AnfNodePtr &node) {
|
||||
IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNodePtr &node) {
|
||||
if (cur_node == node) {
|
||||
return FOLLOW;
|
||||
}
|
||||
#if ENABLE_D
|
||||
if (!IsPrimitiveCNode(node)) {
|
||||
return EXCLUDE;
|
||||
}
|
||||
#else
|
||||
bool is_fuse_composite = AnfAlgo::IsGraphKernel(node);
|
||||
if (!IsPrimitiveCNode(node) && !is_fuse_composite) {
|
||||
return EXCLUDE;
|
||||
}
|
||||
#endif
|
||||
|
||||
bool is_fusable = IsFuse(*info, node);
|
||||
if (is_fusable) {
|
||||
UpdateGraphKernelInfo(info, node);
|
||||
}
|
||||
bool is_fusable = IsFuse(node);
|
||||
return is_fusable ? FOLLOW : EXCLUDE;
|
||||
}
|
||||
|
||||
IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, GraphKernelInfo *info, const AnfNodePtr &node) {
|
||||
IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, const AnfNodePtr &node) {
|
||||
if (cur_node == node) {
|
||||
return FOLLOW;
|
||||
}
|
||||
|
@ -195,14 +91,7 @@ IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, GraphKernelI
|
|||
}
|
||||
return EXCLUDE;
|
||||
}
|
||||
if (!IsPrimitiveCNode(node)) {
|
||||
return EXCLUDE;
|
||||
}
|
||||
|
||||
bool is_fusable = IsFuse(*info, node);
|
||||
if (is_fusable) {
|
||||
UpdateGraphKernelInfo(info, node);
|
||||
}
|
||||
bool is_fusable = IsBasicFuseOp(node);
|
||||
return is_fusable ? FOLLOW : EXCLUDE;
|
||||
}
|
||||
|
||||
|
@ -350,19 +239,15 @@ void TopoSortForNodeList(std::vector<AnfNodePtr> *lst) {
|
|||
lst->assign(res.begin(), res.end());
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode, bool is_before_kernel_select) {
|
||||
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode) {
|
||||
auto func_graph = cnode->func_graph();
|
||||
auto graph_kernel_g = GetValueNode<FuncGraphPtr>(cnode->input(0));
|
||||
GraphKernelInfo info;
|
||||
info.is_before_kernel_select = is_before_kernel_select;
|
||||
GetGraphKernelInfo(graph_kernel_g, &info);
|
||||
auto mng = func_graph->manager();
|
||||
// Search fusable nodes according input direction.
|
||||
auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, &info, std::placeholders::_1);
|
||||
auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, std::placeholders::_1);
|
||||
auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward);
|
||||
std::reverse(used_nodes.begin(), used_nodes.end());
|
||||
// Search fusable nodes according output direction.
|
||||
auto include_func_backward = std::bind(IncludeFusedBasicOpBackward, cnode, &info, std::placeholders::_1);
|
||||
auto include_func_backward = std::bind(IncludeFusedBasicOpBackward, cnode, std::placeholders::_1);
|
||||
auto user_nodes = DeepUsersSearch(cnode, include_func_backward, mng);
|
||||
|
||||
used_nodes.insert(used_nodes.end(), user_nodes.begin() + 1, user_nodes.end());
|
||||
|
@ -373,7 +258,7 @@ std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode, bool is_before_ker
|
|||
return used_nodes;
|
||||
}
|
||||
|
||||
bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph, bool is_before_kernel_select) {
|
||||
bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
bool changed = false;
|
||||
auto &todos = kernel_graph->execution_order();
|
||||
|
@ -392,19 +277,19 @@ bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph,
|
|||
}
|
||||
}
|
||||
|
||||
auto fuse_nodes = FindFuseCNodes(node, is_before_kernel_select);
|
||||
auto fuse_nodes = FindFuseCNodes(node);
|
||||
if (fuse_nodes.size() <= 1) {
|
||||
continue;
|
||||
}
|
||||
changed = true;
|
||||
|
||||
FuseNodesToSubGraph(fuse_nodes, kernel_graph, "", is_before_kernel_select);
|
||||
FuseNodesToSubGraph(fuse_nodes, kernel_graph, "");
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool CompositeOpsFusion::Run(const FuncGraphPtr &func_graph) {
|
||||
return FuseCompositeOps(std::dynamic_pointer_cast<session::KernelGraph>(func_graph), false);
|
||||
return FuseCompositeOps(std::dynamic_pointer_cast<session::KernelGraph>(func_graph));
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -26,25 +26,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
enum GraphKernelType {
|
||||
ELEWISE = 0, // only contain elewise basic ops
|
||||
REDUCE, // contain reduce ops
|
||||
CUBE, // contain cube ops
|
||||
};
|
||||
|
||||
struct GraphKernelInfo {
|
||||
GraphKernelType op_type = ELEWISE;
|
||||
bool is_before_kernel_select = false;
|
||||
int reduce_op_num = 0;
|
||||
int cal_step = 0;
|
||||
std::string origin_composite_name = "";
|
||||
};
|
||||
|
||||
// when composite fuse composite the cal step is greate than this number, not fuse
|
||||
#if ENABLE_D
|
||||
const int MAX_REDUCE_OP_FUSION_CAL_STEP = 5;
|
||||
const int MAX_REDUCE_OP_FUSION_REDUCE_NUM = 2;
|
||||
#endif
|
||||
const std::set<std::string> graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward",
|
||||
"LambNextMV", "LambUpdateWithLR"};
|
||||
|
||||
|
@ -52,7 +33,7 @@ std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bo
|
|||
|
||||
void TopoSortForNodeList(std::vector<AnfNodePtr> *lst);
|
||||
|
||||
bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph, bool is_before_kernel_select = false);
|
||||
bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
|
||||
class CompositeOpsFusion : public Pass {
|
||||
public:
|
||||
|
|
|
@ -176,7 +176,7 @@ AnfNodePtr GraphKernelExpander::CreateExpandGraphKernel(const FuncGraphPtr &func
|
|||
EliminateRedundantParameters(new_func_graph, &inputs);
|
||||
kernel::GetValidKernelNodes(new_func_graph, &kernel_nodes);
|
||||
kernel::GetFuncGraphOutputNodes(new_func_graph, &outputs);
|
||||
auto graph_kernel_node = CreateNewFuseCNode(func_graph, new_func_graph, inputs, outputs, false);
|
||||
auto graph_kernel_node = CreateNewFuseCNode(func_graph, new_func_graph, inputs, outputs);
|
||||
SetNewKernelInfo(graph_kernel_node, new_func_graph, inputs, outputs, AnfAlgo::GetProcessor(node));
|
||||
std::string graph_kernel_flag;
|
||||
std::for_each(kernel_nodes.begin(), kernel_nodes.end(), [&graph_kernel_flag](const AnfNodePtr &node) {
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "vm/segment_runner.h"
|
||||
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
|
||||
#include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h"
|
||||
#include "backend/kernel_compiler/kernel.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "backend/optimizer/pass/const_input_to_attr_registry.h"
|
||||
|
@ -445,7 +446,7 @@ AnfNodePtrList GetExpandOuts(const AnfNodePtrList &outs) {
|
|||
}
|
||||
|
||||
AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &func_graph, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
|
||||
const AnfNodePtrList &outputs, bool is_before_kernel_select) {
|
||||
const AnfNodePtrList &outputs) {
|
||||
auto func_node = NewValueNode(fg);
|
||||
std::vector<AnfNodePtr> fn_inputs;
|
||||
fn_inputs.push_back(func_node);
|
||||
|
@ -467,9 +468,6 @@ AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &func_graph, const FuncGraphPtr
|
|||
auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0);
|
||||
auto input_abs = GetOutputAbstract(kernel_with_index.first, kernel_with_index.second);
|
||||
fg->parameters()[i]->set_abstract(input_abs);
|
||||
if (is_before_kernel_select) {
|
||||
fg->parameters()[i]->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
}
|
||||
}
|
||||
return fuse_cnode;
|
||||
}
|
||||
|
@ -529,8 +527,7 @@ void ReplaceNewFuseCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &new_f
|
|||
}
|
||||
|
||||
void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
|
||||
const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::string &postfix,
|
||||
bool is_before_kernel_select) {
|
||||
const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::string &postfix) {
|
||||
if (fuse_nodes.empty()) {
|
||||
return;
|
||||
}
|
||||
|
@ -547,10 +544,8 @@ void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
|
|||
AnfNodePtrList outputs;
|
||||
|
||||
std::tie(fg, inputs, outputs) = MixedNodesTransToGraph(fuse_nodes, &src_outputs);
|
||||
auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs, is_before_kernel_select);
|
||||
if (!is_before_kernel_select) {
|
||||
SetNewKernelInfo(fuse_new_node, fg, inputs, outputs, AnfAlgo::GetProcessor(fuse_nodes[0]));
|
||||
}
|
||||
auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs);
|
||||
SetNewKernelInfo(fuse_new_node, fg, inputs, outputs, AnfAlgo::GetProcessor(fuse_nodes[0]));
|
||||
// Handle get-item probleam.
|
||||
ReplaceNewFuseCNode(kernel_graph, fuse_new_node, src_outputs);
|
||||
|
||||
|
@ -702,9 +697,20 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNo
|
|||
|
||||
std::unordered_set<PrimitivePtr> GetExpandOps() {
|
||||
std::unordered_set<PrimitivePtr> expand_ops = {
|
||||
prim::kPrimSquare, prim::kPrimBiasAdd, prim::kPrimBiasAddGrad, prim::kPrimGelu,
|
||||
prim::kPrimGeluGrad, prim::kPrimFusedAdam, prim::kPrimFusedAdamWeightDecay, prim::kPrimTanhGrad,
|
||||
prim::kPrimReduceMean, prim::kPrimMaximumGrad, prim::kPrimMinimumGrad};
|
||||
prim::kPrimSquare,
|
||||
#if ENABLE_GPU
|
||||
prim::kPrimBiasAdd,
|
||||
prim::kPrimBiasAddGrad,
|
||||
prim::kPrimGelu,
|
||||
prim::kPrimGeluGrad,
|
||||
prim::kPrimFusedAdam,
|
||||
prim::kPrimFusedAdamWeightDecay,
|
||||
prim::kPrimTanhGrad,
|
||||
prim::kPrimReduceMean,
|
||||
prim::kPrimMaximumGrad,
|
||||
prim::kPrimMinimumGrad
|
||||
#endif
|
||||
};
|
||||
return expand_ops;
|
||||
}
|
||||
|
||||
|
@ -725,16 +731,54 @@ std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &p
|
|||
}
|
||||
|
||||
std::vector<PrimitivePtr> GetFusibleOpList() {
|
||||
#if ENABLE_D
|
||||
std::vector<PrimitivePtr> fusible_basic_ops = {
|
||||
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
|
||||
prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
|
||||
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimCast,
|
||||
prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect,
|
||||
prim::kPrimGreater, prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape,
|
||||
prim::kPrimTranspose};
|
||||
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
|
||||
prim::kPrimExpandDims, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
|
||||
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
|
||||
prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape,
|
||||
prim::kPrimTranspose, prim::kPrimCast};
|
||||
#elif ENABLE_GPU
|
||||
std::vector<PrimitivePtr> fusible_basic_ops = {
|
||||
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
|
||||
prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
|
||||
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
|
||||
prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, prim::kPrimGreater,
|
||||
prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
|
||||
prim::kPrimCast};
|
||||
#else
|
||||
std::vector<PrimitivePtr> fusible_basic_ops;
|
||||
#endif
|
||||
return fusible_basic_ops;
|
||||
}
|
||||
|
||||
bool CheckProcessor(const AnfNodePtr &node, kernel::Processor processor = kernel::Processor::AICORE) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
auto node_kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
||||
if (node_kernel_info == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto node_build_info = node_kernel_info->GetMutableSelectKernelBuildInfo();
|
||||
if (node_build_info == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return node_build_info->processor() == processor;
|
||||
}
|
||||
|
||||
bool IsBasicFuseOp(const AnfNodePtr &node) {
|
||||
std::vector<PrimitivePtr> basic_ops = GetFusibleOpList();
|
||||
#if ENABLE_D
|
||||
if (!CheckProcessor(node)) {
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
return std::any_of(basic_ops.begin(), basic_ops.end(),
|
||||
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
|
||||
}
|
||||
|
||||
void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
|
|
|
@ -45,12 +45,11 @@ void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const
|
|||
const AnfNodePtrList &outputs, kernel::Processor processor);
|
||||
AnfNodePtrList GetExpandOuts(const AnfNodePtrList &outs);
|
||||
AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &kernel_graph, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
|
||||
const AnfNodePtrList &outputs, bool is_before_kernel_select);
|
||||
const AnfNodePtrList &outputs);
|
||||
void ReplaceNewFuseCNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &new_fuse_cnode,
|
||||
const AnfNodePtrList &outputs);
|
||||
void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
|
||||
const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::string &postfix,
|
||||
bool is_before_kernel_select);
|
||||
const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::string &postfix);
|
||||
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc);
|
||||
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc,
|
||||
std::map<std::string, AnfNodePtr> *address_node_map);
|
||||
|
@ -59,6 +58,7 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNo
|
|||
std::unordered_set<PrimitivePtr> GetExpandOps();
|
||||
std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix = "", const string &postfix = "");
|
||||
std::vector<PrimitivePtr> GetFusibleOpList();
|
||||
bool IsBasicFuseOp(const AnfNodePtr &node);
|
||||
void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -45,7 +45,7 @@ bool TensorPromotion::Run(const FuncGraphPtr &func_graph) {
|
|||
AnfNodePtrList inputs, outputs;
|
||||
inputs.insert(inputs.end(), args.begin() + 1, args.end());
|
||||
kernel::GetFuncGraphOutputNodes(fg, &outputs);
|
||||
auto new_cnode = CreateNewFuseCNode(func_graph, fg, inputs, outputs, false);
|
||||
auto new_cnode = CreateNewFuseCNode(func_graph, fg, inputs, outputs);
|
||||
SetNewKernelInfo(new_cnode, fg, inputs, outputs, AnfAlgo::GetProcessor(node));
|
||||
mng->Replace(node, new_cnode);
|
||||
changed = true;
|
||||
|
|
|
@ -41,6 +41,14 @@
|
|||
#include "debug/data_dump/dump_json_parser.h"
|
||||
#include "debug/tensor_load.h"
|
||||
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
|
||||
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
|
||||
#include "backend/optimizer/graph_kernel/tensor_promotion.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_splitter.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_expander.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_cse.h"
|
||||
#include "backend/optimizer/graph_kernel/value_graph_binder.h"
|
||||
#include "backend/optimizer/graph_kernel/add_atomic_clean.h"
|
||||
#include "backend/optimizer/pass/getitem_tuple.h"
|
||||
#include "debug/data_dump/e2e_dump_util.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
#include "debug/dump_proto.h"
|
||||
|
@ -291,8 +299,6 @@ void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) {
|
|||
MS_EXCEPTION_IF_NULL(child_graph);
|
||||
MS_LOG(INFO) << "CompileChildGraph " << child_graph->ToString();
|
||||
opt::AscendBackendIRFusionOptimization(child_graph);
|
||||
opt::AscendBackendFuseBasicOpt(child_graph, true);
|
||||
opt::AscendBackendGraphKernelOpt(child_graph, true);
|
||||
child_graph->SetExecOrderByDefault();
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
|
@ -466,13 +472,35 @@ void AscendSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_
|
|||
MS_LOG(INFO) << "HardwareOptimize start!";
|
||||
opt::AscendBackendOptimization(kernel_graph);
|
||||
opt::AscendGraphKernelCommonProcess(kernel_graph);
|
||||
opt::AscendBackendFuseBasicOpt(kernel_graph, false);
|
||||
opt::AscendBackendAddAtomicClean(kernel_graph);
|
||||
GraphKernelOptimize(kernel_graph);
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
kernel_graph->SetExecOrderByDefault();
|
||||
MS_LOG(INFO) << "HardwareOptimize Finish!";
|
||||
}
|
||||
|
||||
void AscendSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
|
||||
return;
|
||||
}
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>("graph_kernel_pm");
|
||||
pm->AddPass(std::make_shared<opt::GraphKernelExpander>());
|
||||
pm->AddPass(std::make_shared<opt::BasicOpsFusion>());
|
||||
pm->AddPass(std::make_shared<opt::CompositeOpsFusion>());
|
||||
pm->AddPass(std::make_shared<opt::GraphKernelCSE>());
|
||||
pm->AddPass(std::make_shared<opt::TensorPromotion>());
|
||||
pm->AddPass(std::make_shared<opt::GraphKernelSplitter>());
|
||||
// After Simplify and Splitter, a lot of redundant getitem/maketuple
|
||||
// will be exposed, use GetitemTuple Pass to delete them.
|
||||
pm->AddPass(std::make_shared<opt::GetitemTuple>());
|
||||
pm->AddPass(std::make_shared<opt::BindValueToGraph>());
|
||||
pm->AddPass(std::make_shared<opt::CleanAddAtomic>());
|
||||
optimizer->AddPassManager(pm);
|
||||
(void)optimizer->Optimize(kernel_graph);
|
||||
}
|
||||
|
||||
void AscendSession::AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
|
||||
MS_LOG(INFO) << "Start!";
|
||||
opt::HideNopNode(kernel_graph.get());
|
||||
|
@ -865,8 +893,6 @@ void AscendSession::IrFusionPass(const NotNull<KernelGraphPtr> graph, NotNull<st
|
|||
memo->insert(graph.get());
|
||||
|
||||
opt::AscendBackendIRFusionOptimization(graph);
|
||||
opt::AscendBackendFuseBasicOpt(graph, true);
|
||||
opt::AscendBackendGraphKernelOpt(graph, true);
|
||||
graph->SetExecOrderByDefault();
|
||||
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
|
|
|
@ -87,6 +87,7 @@ class AscendSession : public SessionBasic {
|
|||
void InitRuntimeResource();
|
||||
void SelectKernel(const KernelGraph &kernel_graph) const;
|
||||
void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
void GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
void AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
void RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
void AssignStream(NotNull<KernelGraphPtr> kernel_graph) const;
|
||||
|
|
Loading…
Reference in New Issue