forked from mindspore-Ecosystem/mindspore
!25410 add dump flag for fusion nodes
Merge pull request !25410 from yuchaojie/ir_fusion3
This commit is contained in:
commit
8bf7e28fa6
|
@ -85,7 +85,10 @@ CNodePtr CreateFusionOp(const std::vector<AnfNodePtr> &inputs_list, const std::v
|
|||
if (AnfAlgo::HasNodeAttr(kAttrFracZGroup, cnode)) {
|
||||
auto fracz_group = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrFracZGroup);
|
||||
fusion_op->set_attr(kAttrFracZGroup, MakeValue(fracz_group));
|
||||
break;
|
||||
}
|
||||
if (AnfAlgo::HasNodeAttr(kAttrDump, cnode)) {
|
||||
auto dump_flag = AnfAlgo::GetNodeAttr<string>(node, kAttrDump);
|
||||
fusion_op->set_attr(kAttrDump, MakeValue(dump_flag));
|
||||
}
|
||||
}
|
||||
std::vector<AnfNodePtr> fusion_inputs_list = inputs_list;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -71,7 +71,7 @@ const AnfNodePtr InsertPlaceholderForDynamicGRUV2::Process(const FuncGraphPtr &f
|
|||
if (kernel_graph == nullptr) {
|
||||
new_node = std::make_shared<CNode>(*cnode);
|
||||
} else {
|
||||
new_node = kernel_graph->NewCNode(cnode);
|
||||
new_node = NewCNode(cnode, kernel_graph);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(new_node);
|
||||
new_node->set_inputs(new_inputs);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -68,7 +68,7 @@ const AnfNodePtr InsertPlaceholderForDynamicRNN::Process(const FuncGraphPtr &fun
|
|||
if (kernel_graph == nullptr) {
|
||||
new_node = std::make_shared<CNode>(*cnode);
|
||||
} else {
|
||||
new_node = kernel_graph->NewCNode(cnode);
|
||||
new_node = NewCNode(cnode, kernel_graph);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(new_node);
|
||||
new_node->set_inputs(new_inputs);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -15,15 +15,11 @@
|
|||
*/
|
||||
|
||||
#include "backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h"
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
|
||||
namespace mindspore::opt {
|
||||
namespace {
|
||||
using OutputInfo =
|
||||
std::tuple<std::vector<TypeId>, std::vector<std::vector<size_t>>, std::vector<std::string>, std::vector<TypeId>>;
|
||||
OutputInfo GetNodeOutputInfo(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
std::vector<TypeId> output_infer_dtype;
|
||||
|
@ -87,9 +83,12 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AnfNodePtr &concat, con
|
|||
builder.SetOutputsDeviceType(outputs_device_type);
|
||||
return builder.Build();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AnfNodePtr InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const OutputInfo &output_info,
|
||||
const std::vector<AnfNodePtr> &new_tuple_getitems, int64_t rank_size) {
|
||||
AnfNodePtr ConcatOutputsForAllGather::InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const OutputInfo &output_info,
|
||||
const std::vector<AnfNodePtr> &new_tuple_getitems,
|
||||
int64_t rank_size) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
|
||||
size_t inputs_size = AnfAlgo::GetInputTensorNum(node);
|
||||
|
@ -98,7 +97,7 @@ AnfNodePtr InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePt
|
|||
for (size_t j = 0, idx = i; j < LongToSize(rank_size); ++j, idx += inputs_size) {
|
||||
concat_inputs.push_back(new_tuple_getitems[idx]);
|
||||
}
|
||||
auto concat = func_graph->NewCNode(concat_inputs);
|
||||
auto concat = NewCNode(concat_inputs, func_graph);
|
||||
MS_EXCEPTION_IF_NULL(concat);
|
||||
MS_EXCEPTION_IF_NULL(new_tuple_getitems[i]);
|
||||
const std::vector<TypeId> &dtypes = {std::get<0>(output_info)[i]};
|
||||
|
@ -122,7 +121,6 @@ AnfNodePtr InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePt
|
|||
auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
|
||||
return make_tuple;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const BaseRef ConcatOutputsForAllGather::DefinePattern() const {
|
||||
VarPtr Xs = std::make_shared<SeqVar>();
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -18,11 +18,16 @@
|
|||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "backend/optimizer/ascend/ascend_helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
using OutputInfo =
|
||||
std::tuple<std::vector<TypeId>, std::vector<std::vector<size_t>>, std::vector<std::string>, std::vector<TypeId>>;
|
||||
|
||||
class ConcatOutputsForAllGather : public PatternProcessPass {
|
||||
public:
|
||||
explicit ConcatOutputsForAllGather(bool multigraph = true)
|
||||
|
@ -33,6 +38,9 @@ class ConcatOutputsForAllGather : public PatternProcessPass {
|
|||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
AnfNodePtr InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const OutputInfo &output_info, const std::vector<AnfNodePtr> &new_tuple_getitems,
|
||||
int64_t rank_size) const;
|
||||
KernelSelectPtr kernel_select_;
|
||||
};
|
||||
} // namespace opt
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -15,7 +15,6 @@
|
|||
*/
|
||||
|
||||
#include "backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
|
@ -34,14 +33,15 @@ const BaseRef InsertPadForNMSWithMask::DefinePattern() const {
|
|||
return VectorRef({prim::kPrimNMSWithMask, Xs});
|
||||
}
|
||||
|
||||
AnfNodePtr InsertPadToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const TypeId &origin_type,
|
||||
const std::vector<size_t> &origin_shape) {
|
||||
AnfNodePtr InsertPadForNMSWithMask::InsertPadToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
|
||||
const TypeId &origin_type,
|
||||
const std::vector<size_t> &origin_shape) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
std::vector<AnfNodePtr> new_pad_inputs;
|
||||
auto prim = std::make_shared<Primitive>(prim::kPrimPad->name());
|
||||
new_pad_inputs.push_back(NewValueNode(prim));
|
||||
new_pad_inputs.push_back(input);
|
||||
CNodePtr pad = func_graph->NewCNode(new_pad_inputs);
|
||||
CNodePtr pad = NewCNode(new_pad_inputs, func_graph);
|
||||
MS_EXCEPTION_IF_NULL(pad);
|
||||
AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, pad.get());
|
||||
return pad;
|
||||
|
@ -81,7 +81,7 @@ const AnfNodePtr InsertPadForNMSWithMask::Process(const FuncGraphPtr &func_graph
|
|||
if (kernel_graph == nullptr) {
|
||||
new_node = std::make_shared<CNode>(*cnode);
|
||||
} else {
|
||||
new_node = kernel_graph->NewCNode(cnode);
|
||||
new_node = NewCNode(cnode, kernel_graph);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(new_node);
|
||||
new_node->set_inputs(new_inputs);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -16,6 +16,7 @@
|
|||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ENHANCER_INSERT_PAD_FOR_NMS_WITH_MASK_H
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ENHANCER_INSERT_PAD_FOR_NMS_WITH_MASK_H
|
||||
|
||||
#include <vector>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
|
||||
|
@ -28,6 +29,10 @@ class InsertPadForNMSWithMask : public PatternProcessPass {
|
|||
~InsertPadForNMSWithMask() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
AnfNodePtr InsertPadToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const TypeId &origin_type,
|
||||
const std::vector<size_t> &origin_shape) const;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -29,7 +29,7 @@ std::vector<AnfNodePtr> SplitInputsForReduceScatter::InsertSplitForInput(const F
|
|||
for (size_t i = 0; i < inputs_size; i++) {
|
||||
std::vector<AnfNodePtr> split_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
|
||||
split_inputs.push_back(AnfAlgo::GetInputNode(node, i));
|
||||
auto split = func_graph->NewCNode(split_inputs);
|
||||
auto split = NewCNode(split_inputs, func_graph);
|
||||
MS_EXCEPTION_IF_NULL(split);
|
||||
std::vector<TypeId> dtypes(rank_size, AnfAlgo::GetPrevNodeOutputInferDataType(node, i));
|
||||
std::vector<std::vector<size_t>> shapes;
|
||||
|
@ -68,7 +68,7 @@ AnfNodePtr SplitInputsForReduceScatter::RearrangeInputsForReduceScatter(const Fu
|
|||
reduce_scatter_inputs.push_back(inputs[idx]);
|
||||
}
|
||||
}
|
||||
auto reduce_scatter = func_graph->NewCNode(reduce_scatter_inputs);
|
||||
auto reduce_scatter = NewCNode(reduce_scatter_inputs, func_graph);
|
||||
MS_EXCEPTION_IF_NULL(reduce_scatter);
|
||||
reduce_scatter->set_abstract(node->abstract());
|
||||
AnfAlgo::CopyNodeAttrs(node, reduce_scatter);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -162,6 +162,7 @@ CNodePtr DealRefAndSpiltUnSupportedTransdata::MakeDependency(const CNodePtr &get
|
|||
}
|
||||
return func_graph->NewCNode(depend_nodes);
|
||||
}
|
||||
|
||||
CNodePtr DealRefAndSpiltUnSupportedTransdata::DealRefForMultipleOutput(
|
||||
const FuncGraphPtr &func_graph, const CNodePtr &orig_cnode, const std::shared_ptr<kernel::OpInfo> &op_info) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
@ -172,7 +173,7 @@ CNodePtr DealRefAndSpiltUnSupportedTransdata::DealRefForMultipleOutput(
|
|||
if (!update_states.empty()) {
|
||||
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
cnode = kernel_graph->NewCNode(orig_cnode);
|
||||
cnode = NewCNode(orig_cnode, kernel_graph);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
cnode->set_inputs(orig_cnode->inputs());
|
||||
for (auto &update_state : update_states) {
|
||||
|
|
|
@ -36,9 +36,24 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr size_t kType32Len = 4;
|
||||
constexpr size_t kType64Len = 8;
|
||||
|
||||
void UpdateDumpFlag(const AnfNodePtr &node, const std::vector<AnfNodePtr> &orig_nodes) {
|
||||
for (auto &orig_node : orig_nodes) {
|
||||
if (!orig_node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
auto orig_cnode = orig_node->cast<CNodePtr>();
|
||||
if (AnfAlgo::HasNodeAttr(kAttrDump, orig_cnode)) {
|
||||
AnfAlgo::CopyNodeAttr(kAttrDump, orig_cnode, node);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::vector<int64_t> Convert2Int(const std::vector<size_t> &v) {
|
||||
std::vector<int64_t> result;
|
||||
(void)std::transform(v.begin(), v.end(), std::back_inserter(result), SizeToInt);
|
||||
|
@ -101,6 +116,21 @@ bool UnVisited(const BaseRef &n) {
|
|||
return false;
|
||||
}
|
||||
|
||||
// Creates a CNode with the given inputs in `fg`, then lets UpdateDumpFlag carry the dump
// attribute over from `orig_nodes` onto the new node. Raises if `fg` is null.
CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg,
                  const std::vector<AnfNodePtr> &orig_nodes) {
  MS_EXCEPTION_IF_NULL(fg);
  CNodePtr created = fg->NewCNode(inputs);
  UpdateDumpFlag(created, orig_nodes);
  return created;
}
|
||||
|
||||
// Kernel-graph overload: asks `fg` to create a new CNode from `cnode`, then lets UpdateDumpFlag
// carry the dump attribute over from `orig_nodes` onto the new node. Raises if `fg` is null.
CNodePtr NewCNode(const CNodePtr &cnode, const KernelGraphPtr &fg, const std::vector<AnfNodePtr> &orig_nodes) {
  MS_EXCEPTION_IF_NULL(fg);
  CNodePtr created = fg->NewCNode(cnode);
  UpdateDumpFlag(created, orig_nodes);
  return created;
}
|
||||
|
||||
CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, size_t input_size) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<CNode>()) {
|
||||
|
@ -654,7 +684,7 @@ bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
|
|||
}
|
||||
|
||||
namespace {
|
||||
ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
|
||||
ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp, PrimitiveVarMap *primitive_vars) {
|
||||
if (utils::isa<int>(sexp)) {
|
||||
return NewValueNode(utils::cast<int>(sexp));
|
||||
}
|
||||
|
@ -668,7 +698,16 @@ ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
|
|||
return NewValueNode(utils::cast<bool>(sexp));
|
||||
}
|
||||
if (utils::isa<ValuePtr>(sexp)) {
|
||||
return NewValueNode(utils::cast<ValuePtr>(sexp));
|
||||
auto value = utils::cast<ValuePtr>(sexp);
|
||||
if (utils::isa<PrimitivePtr>(sexp)) {
|
||||
auto prim = utils::cast<PrimitivePtr>(sexp);
|
||||
if (primitive_vars->find(prim) != primitive_vars->end()) {
|
||||
prim = std::make_shared<Primitive>(prim->name());
|
||||
value = prim;
|
||||
}
|
||||
(*primitive_vars)[prim] = std::make_shared<Var>(prim);
|
||||
}
|
||||
return NewValueNode(value);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -831,7 +870,7 @@ AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap
|
|||
if (utils::isa<AnfNodePtr>(sexp)) {
|
||||
return utils::cast<AnfNodePtr>(sexp);
|
||||
}
|
||||
auto value_node = CreateValueNodeWithSexp(sexp);
|
||||
auto value_node = CreateValueNodeWithSexp(sexp, primitive_vars);
|
||||
if (value_node == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "sexp cannot converted. sexp: " + sexp.ToString();
|
||||
}
|
||||
|
|
|
@ -137,6 +137,11 @@ bool UnVisited(const BaseRef &n);
|
|||
|
||||
bool Visited(const BaseRef &n);
|
||||
|
||||
CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg,
|
||||
const std::vector<AnfNodePtr> &orig_nodes);
|
||||
|
||||
CNodePtr NewCNode(const CNodePtr &cnode, const KernelGraphPtr &fg, const std::vector<AnfNodePtr> &orig_nodes);
|
||||
|
||||
// check if the input node is a CNode, then check its input_size; return the CNodePtr if the checks succeed.
|
||||
CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, size_t input_size);
|
||||
|
||||
|
|
|
@ -31,7 +31,8 @@ PatternProcessPass::PatternProcessPass(const std::string &name, bool multigraph)
|
|||
: NodePass(name),
|
||||
multigraph_(multigraph),
|
||||
pattern_engine_(PatternEngine(std::make_shared<Visitor>())),
|
||||
primitive_vars_(std::make_shared<PrimitiveVarMap>()) {}
|
||||
primitive_vars_(std::make_shared<PrimitiveVarMap>()),
|
||||
equiv_(std::make_shared<Equiv>()) {}
|
||||
|
||||
const BaseRef PatternProcessPass::DefinePattern() const {
|
||||
VarPtr X = std::make_shared<Var>();
|
||||
|
@ -50,9 +51,10 @@ AnfNodePtr PatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNode
|
|||
|
||||
auto primitive = GetCNodePrimitive(pattern_);
|
||||
if (IsPrimitiveCNode(node, primitive)) {
|
||||
auto empty_equiv = std::make_shared<Equiv>();
|
||||
MS_EXCEPTION_IF_NULL(primitive_vars_);
|
||||
EquivPtr equiv = pattern_engine_.Match(pattern_, node, *primitive_vars_, empty_equiv);
|
||||
MS_EXCEPTION_IF_NULL(equiv_);
|
||||
equiv_->clear();
|
||||
EquivPtr equiv = pattern_engine_.Match(pattern_, node, *primitive_vars_, equiv_);
|
||||
if (equiv != nullptr && !equiv->empty()) {
|
||||
return Process(func_graph, node, equiv);
|
||||
}
|
||||
|
@ -60,21 +62,62 @@ AnfNodePtr PatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNode
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> PatternProcessPass::GetOrigNodes() const {
|
||||
std::vector<AnfNodePtr> orig_nodes;
|
||||
for (auto &prim_var : *primitive_vars_) {
|
||||
if (equiv_->find(prim_var.second) == equiv_->end()) {
|
||||
continue;
|
||||
}
|
||||
auto baseref = (*equiv_)[prim_var.second];
|
||||
if (!utils::isa<CNode>(baseref)) {
|
||||
continue;
|
||||
}
|
||||
auto node = utils::cast<AnfNodePtr>(baseref);
|
||||
orig_nodes.push_back(node);
|
||||
}
|
||||
return orig_nodes;
|
||||
}
|
||||
|
||||
// Creates a CNode with `inputs` in `fg` through opt::NewCNode, passing the nodes returned by
// GetOrigNodes() as the origin nodes. Raises if `fg` is null.
CNodePtr PatternProcessPass::NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) const {
  MS_EXCEPTION_IF_NULL(fg);
  return opt::NewCNode(inputs, fg, GetOrigNodes());
}
|
||||
|
||||
// Kernel-graph overload: creates a new CNode from `cnode` in `fg` through opt::NewCNode, passing
// the nodes returned by GetOrigNodes() as the origin nodes. Raises if `fg` is null.
CNodePtr PatternProcessPass::NewCNode(const CNodePtr &cnode, const KernelGraphPtr &fg) const {
  MS_EXCEPTION_IF_NULL(fg);
  return opt::NewCNode(cnode, fg, GetOrigNodes());
}
|
||||
|
||||
bool MultipleOutputPatternProcessPass::MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
VarPtr fg = std::make_shared<Var>("RootG");
|
||||
auto empty_equiv = std::make_shared<Equiv>();
|
||||
MS_EXCEPTION_IF_NULL(child_primitive_vars_);
|
||||
MS_EXCEPTION_IF_NULL(child_equiv_);
|
||||
EquivPtr another_equiv =
|
||||
child_pattern_engine_.Match(SexpToNode(DefineAnotherPattern(), fg, child_primitive_vars_.get(), true), node,
|
||||
*child_primitive_vars_, empty_equiv);
|
||||
*child_primitive_vars_, child_equiv_);
|
||||
if (another_equiv != nullptr && !another_equiv->empty()) {
|
||||
return IsShareNodes(equiv, another_equiv);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> MultipleOutputPatternProcessPass::GetOrigNodes() const {
|
||||
std::vector<AnfNodePtr> orig_nodes = PatternProcessPass::GetOrigNodes();
|
||||
for (auto &prim_var : *child_primitive_vars_) {
|
||||
auto baseref = (*child_equiv_)[prim_var.second];
|
||||
if (!utils::isa<CNode>(baseref)) {
|
||||
continue;
|
||||
}
|
||||
auto node = utils::cast<AnfNodePtr>(baseref);
|
||||
orig_nodes.push_back(node);
|
||||
}
|
||||
return orig_nodes;
|
||||
}
|
||||
|
||||
void GraphOptimizer::AddPassManager(const PassManagerPtr &pass_manager) {
|
||||
if (pass_manager != nullptr) {
|
||||
pass_managers_.push_back(pass_manager);
|
||||
|
|
|
@ -41,6 +41,11 @@ class PatternProcessPass : public NodePass {
|
|||
virtual const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const = 0;
|
||||
virtual const BaseRef DefinePattern() const;
|
||||
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
|
||||
CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) const;
|
||||
CNodePtr NewCNode(const CNodePtr &cnode, const KernelGraphPtr &fg) const;
|
||||
|
||||
protected:
|
||||
virtual std::vector<AnfNodePtr> GetOrigNodes() const;
|
||||
|
||||
private:
|
||||
void Build();
|
||||
|
@ -49,6 +54,7 @@ class PatternProcessPass : public NodePass {
|
|||
bool multigraph_ = true;
|
||||
PatternEngine pattern_engine_;
|
||||
PrimitiveVarMapPtr primitive_vars_;
|
||||
EquivPtr equiv_;
|
||||
};
|
||||
|
||||
class MultipleOutputPatternProcessPass : public PatternProcessPass {
|
||||
|
@ -56,7 +62,8 @@ class MultipleOutputPatternProcessPass : public PatternProcessPass {
|
|||
explicit MultipleOutputPatternProcessPass(const std::string &name = "", bool multigraph = true)
|
||||
: PatternProcessPass(name, multigraph),
|
||||
child_pattern_engine_(PatternEngine(std::make_shared<Visitor>())),
|
||||
child_primitive_vars_(std::make_shared<PrimitiveVarMap>()) {}
|
||||
child_primitive_vars_(std::make_shared<PrimitiveVarMap>()),
|
||||
child_equiv_(std::make_shared<Equiv>()) {}
|
||||
~MultipleOutputPatternProcessPass() override = default;
|
||||
virtual BaseRef DefineAnotherPattern() const = 0;
|
||||
// check whether the two patterns share the same nodes or not
|
||||
|
@ -64,8 +71,10 @@ class MultipleOutputPatternProcessPass : public PatternProcessPass {
|
|||
|
||||
protected:
|
||||
bool MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const;
|
||||
std::vector<AnfNodePtr> GetOrigNodes() const override;
|
||||
PatternEngine child_pattern_engine_;
|
||||
PrimitiveVarMapPtr child_primitive_vars_;
|
||||
EquivPtr child_equiv_;
|
||||
};
|
||||
|
||||
class GraphOptimizer {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -25,6 +25,7 @@
|
|||
#include "runtime/device/kernel_info.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/kernel_compiler/kernel_build_info.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "frontend/parallel/context.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -318,16 +319,18 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr
|
|||
if (end_index >= communication_op_info.communication_op_nodes.size()) {
|
||||
MS_LOG(EXCEPTION) << "end index out of communication_op_nodes size";
|
||||
}
|
||||
std::vector<AnfNodePtr> orig_nodes;
|
||||
for (size_t idx = start_index; idx <= end_index; ++idx) {
|
||||
auto cnode = communication_op_info.communication_op_nodes[idx];
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (idx != start_index) {
|
||||
AdjustAllReduceInputWithLoad(cnode);
|
||||
}
|
||||
fusion_inputs.insert(fusion_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
|
||||
(void)fusion_inputs.insert(fusion_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
|
||||
(void)orig_nodes.emplace_back(cnode);
|
||||
}
|
||||
CheckInputs(fusion_inputs);
|
||||
AnfNodePtr fused_node = func_graph->NewCNode(fusion_inputs);
|
||||
AnfNodePtr fused_node = NewCNode(fusion_inputs, func_graph, orig_nodes);
|
||||
MS_EXCEPTION_IF_NULL(fused_node);
|
||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
|
@ -450,7 +453,7 @@ bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) {
|
|||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
const float input_grad_size_num = 0.0;
|
||||
const float input_grad_time_num = 0.0;
|
||||
// divide candidate fusion groups with same (group,op,fusion) attrs, fusion==0 means not fusion
|
||||
// divide candidate fusion groups with same (group,op,fusion,dtype) attrs, fusion==0 means not fusion
|
||||
std::unordered_map<std::string, CommunicationOpInfo> candidate_groups;
|
||||
std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
|
||||
for (auto &node : node_list) {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -58,8 +58,10 @@ AnfNodePtr CreateTensorInput(const KernelGraphPtr &kernel_graph, const AnfNodePt
|
|||
tensor_input->set_scope(input_node->scope());
|
||||
return tensor_input;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AnfNodePtr ConstInputToTensorInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
AnfNodePtr ConvertConstInputToTensorInput::ConstInputToTensorInput(const FuncGraphPtr &func_graph,
|
||||
const CNodePtr &cnode) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
const std::set<std::string> no_need_to_convert_nodes = {kStackOpName};
|
||||
|
@ -89,7 +91,7 @@ AnfNodePtr ConstInputToTensorInput(const FuncGraphPtr &func_graph, const CNodePt
|
|||
}
|
||||
if (need_update) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto new_cnode = func_graph->NewCNode(new_inputs);
|
||||
auto new_cnode = NewCNode(new_inputs, func_graph);
|
||||
MS_EXCEPTION_IF_NULL(new_cnode);
|
||||
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimDepend)) {
|
||||
new_cnode->set_abstract(new_inputs[1]->abstract());
|
||||
|
@ -105,7 +107,6 @@ AnfNodePtr ConstInputToTensorInput(const FuncGraphPtr &func_graph, const CNodePt
|
|||
}
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const AnfNodePtr ConvertConstInputToTensorInput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -29,6 +29,9 @@ class ConvertConstInputToTensorInput : public PatternProcessPass {
|
|||
: PatternProcessPass("convert_const_input_to_tensor_input", multigraph) {}
|
||||
~ConvertConstInputToTensorInput() override = default;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
AnfNodePtr ConstInputToTensorInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -81,7 +81,7 @@ const AnfNodePtr ConvertConstScalarToTensor::Process(const FuncGraphPtr &func_gr
|
|||
if (kernel_graph == nullptr || !input_changed) {
|
||||
return nullptr;
|
||||
}
|
||||
return kernel_graph->NewCNode(cnode);
|
||||
return NewCNode(cnode, kernel_graph);
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -88,7 +88,7 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func
|
|||
if (kernel_graph == nullptr || !cnode_input_changed) {
|
||||
return nullptr;
|
||||
}
|
||||
return kernel_graph->NewCNode(cnode);
|
||||
return NewCNode(cnode, kernel_graph);
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -26,7 +26,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
using KernelWithIndex = std::pair<CNodePtr, size_t>;
|
||||
namespace {
|
||||
CNodePtr GetRealPrevCNode(const AnfNodePtr &node, size_t index, std::vector<KernelWithIndex> *pass_vector) {
|
||||
MS_EXCEPTION_IF_NULL(pass_vector);
|
||||
|
@ -78,9 +77,11 @@ bool TransDataOpEliminateCondition(const CNodePtr &node1, const CNodePtr &node2)
|
|||
AnfAlgo::GetOutputFormat(node1, 0) == AnfAlgo::GetInputFormat(node2, 0) &&
|
||||
kernel::IsSameShape(AnfAlgo::GetInputDeviceShape(node2, 0), AnfAlgo::GetOutputDeviceShape(node1, 0));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const AnfNodePtr ProcessMatchedNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const CNodePtr &prev_cnode,
|
||||
std::vector<KernelWithIndex> *pass_vector) {
|
||||
const AnfNodePtr EliminateRedundantOp::ProcessMatchedNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
const CNodePtr &prev_cnode,
|
||||
std::vector<KernelWithIndex> *pass_vector) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(pass_vector);
|
||||
FuncGraphManagerPtr manager = func_graph->manager();
|
||||
|
@ -113,7 +114,7 @@ const AnfNodePtr ProcessMatchedNodes(const FuncGraphPtr &func_graph, const CNode
|
|||
MS_LOG(ERROR) << "pass_size should >= 2";
|
||||
}
|
||||
for (size_t idx = pass_size - kOffset; idx > 0; --idx) {
|
||||
auto new_node = func_graph->NewCNode((*pass_vector)[idx].first->inputs());
|
||||
auto new_node = NewCNode((*pass_vector)[idx].first->inputs(), func_graph);
|
||||
if (idx == pass_size - kOffset) {
|
||||
new_node->set_input((*pass_vector)[idx].second,
|
||||
(*pass_vector)[idx + 1].first->input((*pass_vector)[idx + 1].second));
|
||||
|
@ -125,7 +126,6 @@ const AnfNodePtr ProcessMatchedNodes(const FuncGraphPtr &func_graph, const CNode
|
|||
return (*pass_vector)[1].first;
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void EliminateRedundantOp::Init() {
|
||||
(void)redundant_process_map_.emplace(std::pair<std::string, RedundantOpPair>(
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -29,6 +29,7 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
using ConditionFunc = std::function<bool(const CNodePtr &node1, const CNodePtr &node2)>;
|
||||
using RedundantOpPair = std::pair<std::string, ConditionFunc>;
|
||||
using KernelWithIndex = std::pair<CNodePtr, size_t>;
|
||||
|
||||
class EliminateRedundantOp : public PatternProcessPass {
|
||||
public:
|
||||
|
@ -41,6 +42,8 @@ class EliminateRedundantOp : public PatternProcessPass {
|
|||
private:
|
||||
void Init();
|
||||
const AnfNodePtr DoEliminate(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const;
|
||||
const AnfNodePtr ProcessMatchedNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
const CNodePtr &prev_cnode, std::vector<KernelWithIndex> *pass_vector) const;
|
||||
std::unordered_map<std::string, RedundantOpPair> redundant_process_map_;
|
||||
};
|
||||
} // namespace opt
|
||||
|
|
|
@ -28,46 +28,8 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
const int axis_input_index = 2;
|
||||
namespace {
|
||||
// Builds a Rank CNode in `kernel_graph`. Its single data input is the node returned by
// AnfAlgo::GetPrevNodeOutput(cnode, 1), and the new node reuses that input's abstract.
AnfNodePtr NewRankOp(const AnfNodePtr &cnode, const KernelGraphPtr &kernel_graph) {
  auto rank_prim = std::make_shared<Primitive>(prim::kPrimRank->name());
  auto producer = AnfAlgo::GetPrevNodeOutput(cnode, 1);
  std::vector<AnfNodePtr> rank_inputs{NewValueNode(rank_prim), producer.first};
  auto rank_op = kernel_graph->NewCNode(rank_inputs);
  MS_EXCEPTION_IF_NULL(rank_op);
  rank_op->set_abstract(producer.first->abstract());
  return rank_op;
}
|
||||
|
||||
// Builds a Range CNode in `kernel_graph` with the inputs (start = 0, limit = rank_op, delta = 1).
// The new node's abstract is copied from `rank_op`.
AnfNodePtr NewRangeOp(const AnfNodePtr &rank_op, const KernelGraphPtr &kernel_graph) {
  // Both pointers are dereferenced below; fail with a framework exception rather than a null dereference.
  MS_EXCEPTION_IF_NULL(rank_op);
  MS_EXCEPTION_IF_NULL(kernel_graph);
  std::vector<AnfNodePtr> range_inputs;
  auto prim = std::make_shared<Primitive>(prim::kPrimRange->name());
  range_inputs.push_back(NewValueNode(prim));
  // "start"
  auto start_ = NewValueNode(SizeToLong(0));
  MS_EXCEPTION_IF_NULL(start_);
  auto imm_start = std::make_shared<Int64Imm>(SizeToLong(0));
  start_->set_abstract(std::make_shared<abstract::AbstractScalar>(imm_start));
  range_inputs.push_back(start_);

  // "limit"
  range_inputs.push_back(rank_op);

  // "delta"
  auto delta_ = NewValueNode(SizeToLong(1));
  MS_EXCEPTION_IF_NULL(delta_);
  auto imm_delta = std::make_shared<Int64Imm>(SizeToLong(1));
  delta_->set_abstract(std::make_shared<abstract::AbstractScalar>(imm_delta));
  range_inputs.push_back(delta_);

  // new range op
  auto range_op = kernel_graph->NewCNode(range_inputs);
  MS_EXCEPTION_IF_NULL(range_op);
  range_op->set_abstract(rank_op->abstract());
  return range_op;
}
|
||||
const int axis_input_index = 2;
|
||||
|
||||
bool IsNeedComputeRank(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
|
@ -95,8 +57,48 @@ bool IsNeedComputeRank(const CNodePtr &cnode) {
|
|||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AnfNodePtr InsertAssistNode(const CNodePtr &cnode, const KernelGraphPtr &kernel_graph) {
|
||||
// Builds a Rank CNode through the NewCNode(inputs, graph) helper. Its only data input is the node returned by
// AnfAlgo::GetPrevNodeOutput(cnode, 1), and the new node's abstract is copied from that input.
// NOTE(review): the copied abstract is the input's, not a scalar rank — confirm this is intended.
AnfNodePtr ReduceSumOptimizer::NewRankOp(const AnfNodePtr &cnode, const KernelGraphPtr &kernel_graph) const {
  std::vector<AnfNodePtr> rank_inputs;
  auto prim = std::make_shared<Primitive>(prim::kPrimRank->name());
  rank_inputs.push_back(NewValueNode(prim));
  auto prev_node = AnfAlgo::GetPrevNodeOutput(cnode, 1);
  // prev_node.first is dereferenced for its abstract below, so reject a null producer up front.
  MS_EXCEPTION_IF_NULL(prev_node.first);
  rank_inputs.push_back(prev_node.first);
  auto rank_op = NewCNode(rank_inputs, kernel_graph);
  MS_EXCEPTION_IF_NULL(rank_op);
  rank_op->set_abstract(prev_node.first->abstract());
  return rank_op;
}
|
||||
|
||||
// Builds a Range CNode through the NewCNode(inputs, graph) helper with the inputs
// (start = 0, limit = rank_op, delta = 1). The new node's abstract is copied from `rank_op`.
AnfNodePtr ReduceSumOptimizer::NewRangeOp(const AnfNodePtr &rank_op, const KernelGraphPtr &kernel_graph) const {
  // rank_op->abstract() is read at the end; fail with a framework exception rather than a null dereference.
  MS_EXCEPTION_IF_NULL(rank_op);
  std::vector<AnfNodePtr> range_inputs;
  auto prim = std::make_shared<Primitive>(prim::kPrimRange->name());
  range_inputs.push_back(NewValueNode(prim));
  // "start"
  auto start_ = NewValueNode(SizeToLong(0));
  MS_EXCEPTION_IF_NULL(start_);
  auto imm_start = std::make_shared<Int64Imm>(SizeToLong(0));
  start_->set_abstract(std::make_shared<abstract::AbstractScalar>(imm_start));
  range_inputs.push_back(start_);

  // "limit"
  range_inputs.push_back(rank_op);

  // "delta"
  auto delta_ = NewValueNode(SizeToLong(1));
  MS_EXCEPTION_IF_NULL(delta_);
  auto imm_delta = std::make_shared<Int64Imm>(SizeToLong(1));
  delta_->set_abstract(std::make_shared<abstract::AbstractScalar>(imm_delta));
  range_inputs.push_back(delta_);

  // new range op
  auto range_op = NewCNode(range_inputs, kernel_graph);
  MS_EXCEPTION_IF_NULL(range_op);
  range_op->set_abstract(rank_op->abstract());
  return range_op;
}
|
||||
|
||||
AnfNodePtr ReduceSumOptimizer::InsertAssistNode(const CNodePtr &cnode, const KernelGraphPtr &kernel_graph) const {
|
||||
// The input dim is unknown, so rank + range would be needed; this is not supported yet.
|
||||
MS_LOG(EXCEPTION)
|
||||
<< "Can not support the case that input is dim unknown and axis is empty or axis contain value less 0. node: "
|
||||
|
@ -108,7 +110,7 @@ AnfNodePtr InsertAssistNode(const CNodePtr &cnode, const KernelGraphPtr &kernel_
|
|||
std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
|
||||
new_inputs.push_back(cnode->input(1));
|
||||
new_inputs.push_back(range_op);
|
||||
auto new_node = kernel_graph->NewCNode(cnode);
|
||||
auto new_node = NewCNode(cnode, kernel_graph);
|
||||
MS_EXCEPTION_IF_NULL(new_node);
|
||||
new_node->set_inputs(new_inputs);
|
||||
return new_node;
|
||||
|
@ -121,7 +123,7 @@ AnfNodePtr InsertAssistNode(const CNodePtr &cnode, const KernelGraphPtr &kernel_
|
|||
// 2: the value of axis_input contain the value less 0,
|
||||
// the new tensor of the new value node should be "shape.size() + the_old_value_less_0",
|
||||
// the shape is the shape of the first input of ReduceSum;
|
||||
AnfNodePtr NewAssistValueNode(const CNodePtr &cnode, const KernelGraphPtr &kernel_graph) {
|
||||
AnfNodePtr ReduceSumOptimizer::NewAssistValueNode(const CNodePtr &cnode, const KernelGraphPtr &kernel_graph) const {
|
||||
// axis is a tuple; it may be empty or contain a value less than 0;
|
||||
auto axis_input = cnode->input(axis_input_index);
|
||||
if (IsValueNode<ValueTuple>(axis_input)) {
|
||||
|
@ -159,7 +161,7 @@ AnfNodePtr NewAssistValueNode(const CNodePtr &cnode, const KernelGraphPtr &kerne
|
|||
assist_node->set_abstract(std::make_shared<abstract::AbstractTensor>(kInt64, axes_value));
|
||||
auto assist_value_node = kernel_graph->NewValueNode(assist_node);
|
||||
new_inputs.push_back(assist_value_node);
|
||||
auto new_node = kernel_graph->NewCNode(cnode);
|
||||
auto new_node = NewCNode(cnode, kernel_graph);
|
||||
MS_EXCEPTION_IF_NULL(new_node);
|
||||
new_node->set_inputs(new_inputs);
|
||||
return new_node;
|
||||
|
@ -167,7 +169,6 @@ AnfNodePtr NewAssistValueNode(const CNodePtr &cnode, const KernelGraphPtr &kerne
|
|||
}
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const AnfNodePtr ReduceSumOptimizer::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
|
|
|
@ -26,6 +26,12 @@ class ReduceSumOptimizer : public PatternProcessPass {
|
|||
~ReduceSumOptimizer() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
AnfNodePtr NewRankOp(const AnfNodePtr &cnode, const KernelGraphPtr &kernel_graph) const;
|
||||
AnfNodePtr NewRangeOp(const AnfNodePtr &rank_op, const KernelGraphPtr &kernel_graph) const;
|
||||
AnfNodePtr InsertAssistNode(const CNodePtr &cnode, const KernelGraphPtr &kernel_graph) const;
|
||||
AnfNodePtr NewAssistValueNode(const CNodePtr &cnode, const KernelGraphPtr &kernel_graph) const;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -2001,6 +2001,15 @@ bool AnfRuntimeAlgorithm::GetBooleanAttr(const AnfNodePtr &node, const std::stri
|
|||
return AnfAlgo::GetNodeAttr<bool>(node, attr);
|
||||
}
|
||||
|
||||
std::optional<string> AnfRuntimeAlgorithm::GetDumpFlag(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (cnode == nullptr || !AnfAlgo::HasNodeAttr(kAttrDump, cnode)) {
|
||||
return std::optional<string>{};
|
||||
}
|
||||
return std::optional<string>{AnfAlgo::GetNodeAttr<string>(node, kAttrDump)};
|
||||
}
|
||||
|
||||
bool AnfRuntimeAlgorithm::HasDynamicShapeFlag(const PrimitivePtr &prim) {
|
||||
auto get_bool_attr = [](const PrimitivePtr &primitive, const std::string &attr_name) -> bool {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include <memory>
|
||||
#include <unordered_set>
|
||||
#include <map>
|
||||
#include <optional>
|
||||
#include "ir/anf.h"
|
||||
#include "ir/dtype.h"
|
||||
#include "base/base.h"
|
||||
|
@ -288,6 +289,7 @@ class AnfRuntimeAlgorithm {
|
|||
static bool IsCondControlKernel(const CNodePtr &node);
|
||||
static bool IsIndependentNode(const CNodePtr &node);
|
||||
static bool GetBooleanAttr(const AnfNodePtr &node, const std::string &attr);
|
||||
static std::optional<string> GetDumpFlag(const AnfNodePtr &node);
|
||||
static void GetRealDynamicShape(const std::vector<size_t> &shape, NotNull<std::vector<int64_t> *> dynamic_shape);
|
||||
static std::vector<int64_t> GetInputMaxShape(const AnfNodePtr &anf_node, size_t index);
|
||||
static std::vector<int64_t> GetInputMinShape(const AnfNodePtr &anf_node, size_t index);
|
||||
|
|
|
@ -483,6 +483,7 @@ constexpr auto kAttrProfilingIterEnd = "PROFILING_ITER_END";
|
|||
constexpr auto kAttrHiddenSize = "hidden_size";
|
||||
constexpr auto kAttrInputSize = "input_size";
|
||||
constexpr auto kAttrDstType = "dst_type";
|
||||
constexpr auto kAttrDump = "dump";
|
||||
constexpr auto kAttrSkipNopOpAddr = "skip_nop_op_addr";
|
||||
constexpr auto kAttrFuncType = "func_type";
|
||||
|
||||
|
|
|
@ -1218,5 +1218,16 @@ int GetDataTypeFromAnfNode(const AnfNodePtr &anf_node, TypeId *type_id) {
|
|||
*type_id = type_ptr->type_id();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
// not implement for lite, just for api compatible
|
||||
CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg,
|
||||
const std::vector<AnfNodePtr> &orig_nodes) {
|
||||
return fg->NewCNode(inputs);
|
||||
}
|
||||
|
||||
// not implement for lite, just for api compatible
|
||||
CNodePtr NewCNode(const CNodePtr &cnode, const KernelGraphPtr &fg, const std::vector<AnfNodePtr> &orig_nodes) {
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -16,14 +16,18 @@
|
|||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
|
||||
#include "common/common_test.h"
|
||||
#include "backend/optimizer/common/pattern_engine.h"
|
||||
#include "backend/optimizer/common/visit.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "base/base_ref.h"
|
||||
#include "base/core_ops.h"
|
||||
#include "ir/anf.h"
|
||||
#include "utils/utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
using PatternListType = std::initializer_list<BaseRef>;
|
||||
|
@ -32,15 +36,16 @@ bool Equal(const BaseRef &a, const BaseRef &b) { return a == b; }
|
|||
|
||||
class TestMatchEngine : public UT::Common {
|
||||
public:
|
||||
TestMatchEngine()
|
||||
: TU(std::make_shared<Visitor>()) {
|
||||
TestMatchEngine() : TU(std::make_shared<Visitor>()) {
|
||||
equiv_null = std::make_shared<Equiv>();
|
||||
fg = std::make_shared<FuncGraph>();
|
||||
};
|
||||
|
||||
public:
|
||||
PatternEngine TU;
|
||||
EquivPtr equiv_null;
|
||||
PrimitiveVarMap primitive_vars_null;
|
||||
FuncGraphPtr fg;
|
||||
};
|
||||
|
||||
TEST_F(TestMatchEngine, Var) {
|
||||
|
@ -215,4 +220,79 @@ TEST_F(TestMatchEngine, Match_CondVar) {
|
|||
equiv_null);
|
||||
ASSERT_EQ(d, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: Backend support dump flag
|
||||
/// Description: PatternEngine match var with primitive
|
||||
/// Expectation: Get correct Equiv map
|
||||
TEST_F(TestMatchEngine, Match_PrimVar) {
|
||||
VarPtr mul1 = std::make_shared<Var>(std::make_shared<Primitive>(kMulOpName));
|
||||
VarPtr mul2 = std::make_shared<Var>(std::make_shared<Primitive>(kMulOpName));
|
||||
VarPtr v1 = std::make_shared<Var>();
|
||||
VarPtr sv2 = std::make_shared<SeqVar>();
|
||||
auto pattern_ref = VectorRef({mul1, v1, VectorRef({mul2, sv2})});
|
||||
PrimitiveVarMapPtr primitive_vars = std::make_shared<PrimitiveVarMap>();
|
||||
auto pattern_node = opt::SexpToNode(pattern_ref, fg, primitive_vars.get(), true);
|
||||
ASSERT_EQ(primitive_vars->size(), std::size_t(2));
|
||||
|
||||
auto anode1 = std::make_shared<AnfNode>(fg);
|
||||
auto anode2 = std::make_shared<AnfNode>(fg);
|
||||
auto anode3 = std::make_shared<AnfNode>(fg);
|
||||
AnfNodePtr mul2_cnode = std::make_shared<CNode>(
|
||||
std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kMulOpName)), anode2, anode3}, fg);
|
||||
AnfNodePtr mul1_cnode = std::make_shared<CNode>(
|
||||
std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kMulOpName)), anode1, mul2_cnode}, fg);
|
||||
|
||||
EquivPtr d;
|
||||
equiv_null->clear();
|
||||
d = TU.Match(pattern_node, mul1_cnode, *primitive_vars, equiv_null);
|
||||
ASSERT_EQ(d->size(), std::size_t(4));
|
||||
ASSERT_EQ((*d)[mul2], mul2_cnode);
|
||||
ASSERT_EQ((*d)[mul1], mul1_cnode);
|
||||
|
||||
AnfNodePtr sub_cnode = std::make_shared<CNode>(
|
||||
std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kSubOpName)), anode1, mul2_cnode}, fg);
|
||||
|
||||
equiv_null->clear();
|
||||
d = TU.Match(pattern_node, sub_cnode, *primitive_vars, equiv_null);
|
||||
ASSERT_EQ(d, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: Backend support dump flag
|
||||
/// Description: PatternEngine match primitive
|
||||
/// Expectation: Get correct Equiv map
|
||||
TEST_F(TestMatchEngine, Match_Prim) {
|
||||
VarPtr v1 = std::make_shared<Var>();
|
||||
VarPtr sv2 = std::make_shared<SeqVar>();
|
||||
auto pattern_ref = VectorRef({prim::kPrimMul, v1, VectorRef({prim::kPrimMul, sv2})});
|
||||
PrimitiveVarMapPtr primitive_vars = std::make_shared<PrimitiveVarMap>();
|
||||
auto pattern_node = opt::SexpToNode(pattern_ref, fg, primitive_vars.get(), true);
|
||||
ASSERT_EQ(primitive_vars->size(), std::size_t(2));
|
||||
|
||||
auto anode1 = std::make_shared<AnfNode>(fg);
|
||||
auto anode2 = std::make_shared<AnfNode>(fg);
|
||||
auto anode3 = std::make_shared<AnfNode>(fg);
|
||||
AnfNodePtr mul2_cnode = std::make_shared<CNode>(
|
||||
std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kMulOpName)), anode2, anode3}, fg);
|
||||
AnfNodePtr mul1_cnode = std::make_shared<CNode>(
|
||||
std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kMulOpName)), anode1, mul2_cnode}, fg);
|
||||
|
||||
EquivPtr d;
|
||||
equiv_null->clear();
|
||||
d = TU.Match(pattern_node, mul1_cnode, *primitive_vars, equiv_null);
|
||||
ASSERT_EQ(d->size(), std::size_t(4));
|
||||
for (auto &prim_var : *primitive_vars) {
|
||||
if (prim_var.first == prim::kPrimMul) {
|
||||
ASSERT_EQ((*d)[prim_var.second], mul1_cnode);
|
||||
} else {
|
||||
ASSERT_EQ((*d)[prim_var.second], mul2_cnode);
|
||||
}
|
||||
}
|
||||
|
||||
AnfNodePtr sub_cnode = std::make_shared<CNode>(
|
||||
std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kSubOpName)), anode1, mul2_cnode}, fg);
|
||||
|
||||
equiv_null->clear();
|
||||
d = TU.Match(pattern_node, sub_cnode, *primitive_vars, equiv_null);
|
||||
ASSERT_EQ(d, nullptr);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include "common/common_test.h"
|
||||
#define private public
|
||||
#define protected public
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#undef private
|
||||
#undef protected
|
||||
|
||||
#include "base/core_ops.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/value.h"
|
||||
#include "utils/utils.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class TestPass : public PatternProcessPass {
|
||||
public:
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const { return nullptr; };
|
||||
};
|
||||
|
||||
class TestPatternProcessPass : public UT::Common {
|
||||
public:
|
||||
TestPatternProcessPass() : TU() { fg = std::make_shared<FuncGraph>(); };
|
||||
|
||||
public:
|
||||
TestPass TU;
|
||||
FuncGraphPtr fg;
|
||||
};
|
||||
|
||||
/// Feature: Backend support dump flag
|
||||
/// Description: Get orig nodes according to primitive_vars_ and equiv_
|
||||
/// Expectation: Get correct orig nodes
|
||||
TEST_F(TestPatternProcessPass, test_GetOrigNodes) {
|
||||
TU.primitive_vars_->clear();
|
||||
TU.equiv_->clear();
|
||||
VarPtr mul1 = std::make_shared<Var>(std::make_shared<Primitive>(kMulOpName));
|
||||
VarPtr v1 = std::make_shared<Var>();
|
||||
VarPtr v2 = std::make_shared<Var>();
|
||||
(*TU.primitive_vars_)[mul1->primitive()] = mul1;
|
||||
|
||||
auto mul1_node = std::make_shared<CNode>(std::vector<AnfNodePtr>{NewValueNode(prim::kPrimMul)}, fg);
|
||||
auto anode1 = std::make_shared<AnfNode>(fg);
|
||||
auto anode2 = std::make_shared<AnfNode>(fg);
|
||||
(*TU.equiv_)[mul1] = mul1_node;
|
||||
(*TU.equiv_)[v1] = anode1;
|
||||
(*TU.equiv_)[v2] = anode2;
|
||||
|
||||
auto orig_nodes = TU.GetOrigNodes();
|
||||
ASSERT_EQ(orig_nodes.size(), std::size_t(1));
|
||||
ASSERT_EQ(orig_nodes[0], mul1_node);
|
||||
|
||||
VarPtr mul2 = std::make_shared<Var>(std::make_shared<Primitive>(kMulOpName));
|
||||
(*TU.primitive_vars_)[mul2->primitive()] = mul2;
|
||||
orig_nodes = TU.GetOrigNodes();
|
||||
ASSERT_EQ(orig_nodes.size(), std::size_t(1));
|
||||
|
||||
auto mul2_node = std::make_shared<CNode>(std::vector<AnfNodePtr>{NewValueNode(prim::kPrimMul)}, fg);
|
||||
(*TU.equiv_)[mul2] = mul2_node;
|
||||
orig_nodes = TU.GetOrigNodes();
|
||||
ASSERT_EQ(orig_nodes.size(), std::size_t(2));
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
Loading…
Reference in New Issue