!25410 add dump flag for fusion nodes

Merge pull request !25410 from yuchaojie/ir_fusion3
This commit is contained in:
i-robot 2021-11-10 02:25:08 +00:00 committed by Gitee
commit 8bf7e28fa6
29 changed files with 417 additions and 102 deletions

View File

@ -85,7 +85,10 @@ CNodePtr CreateFusionOp(const std::vector<AnfNodePtr> &inputs_list, const std::v
if (AnfAlgo::HasNodeAttr(kAttrFracZGroup, cnode)) {
auto fracz_group = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrFracZGroup);
fusion_op->set_attr(kAttrFracZGroup, MakeValue(fracz_group));
break;
}
if (AnfAlgo::HasNodeAttr(kAttrDump, cnode)) {
auto dump_flag = AnfAlgo::GetNodeAttr<string>(node, kAttrDump);
fusion_op->set_attr(kAttrDump, MakeValue(dump_flag));
}
}
std::vector<AnfNodePtr> fusion_inputs_list = inputs_list;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -71,7 +71,7 @@ const AnfNodePtr InsertPlaceholderForDynamicGRUV2::Process(const FuncGraphPtr &f
if (kernel_graph == nullptr) {
new_node = std::make_shared<CNode>(*cnode);
} else {
new_node = kernel_graph->NewCNode(cnode);
new_node = NewCNode(cnode, kernel_graph);
}
MS_EXCEPTION_IF_NULL(new_node);
new_node->set_inputs(new_inputs);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -68,7 +68,7 @@ const AnfNodePtr InsertPlaceholderForDynamicRNN::Process(const FuncGraphPtr &fun
if (kernel_graph == nullptr) {
new_node = std::make_shared<CNode>(*cnode);
} else {
new_node = kernel_graph->NewCNode(cnode);
new_node = NewCNode(cnode, kernel_graph);
}
MS_EXCEPTION_IF_NULL(new_node);
new_node->set_inputs(new_inputs);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -15,15 +15,11 @@
*/
#include "backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h"
#include <string>
#include <tuple>
#include <utility>
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore::opt {
namespace {
using OutputInfo =
std::tuple<std::vector<TypeId>, std::vector<std::vector<size_t>>, std::vector<std::string>, std::vector<TypeId>>;
OutputInfo GetNodeOutputInfo(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
std::vector<TypeId> output_infer_dtype;
@ -87,9 +83,12 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AnfNodePtr &concat, con
builder.SetOutputsDeviceType(outputs_device_type);
return builder.Build();
}
} // namespace
AnfNodePtr InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const OutputInfo &output_info,
const std::vector<AnfNodePtr> &new_tuple_getitems, int64_t rank_size) {
AnfNodePtr ConcatOutputsForAllGather::InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const OutputInfo &output_info,
const std::vector<AnfNodePtr> &new_tuple_getitems,
int64_t rank_size) const {
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
size_t inputs_size = AnfAlgo::GetInputTensorNum(node);
@ -98,7 +97,7 @@ AnfNodePtr InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePt
for (size_t j = 0, idx = i; j < LongToSize(rank_size); ++j, idx += inputs_size) {
concat_inputs.push_back(new_tuple_getitems[idx]);
}
auto concat = func_graph->NewCNode(concat_inputs);
auto concat = NewCNode(concat_inputs, func_graph);
MS_EXCEPTION_IF_NULL(concat);
MS_EXCEPTION_IF_NULL(new_tuple_getitems[i]);
const std::vector<TypeId> &dtypes = {std::get<0>(output_info)[i]};
@ -122,7 +121,6 @@ AnfNodePtr InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePt
auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
return make_tuple;
}
} // namespace
const BaseRef ConcatOutputsForAllGather::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -18,11 +18,16 @@
#include <memory>
#include <vector>
#include <string>
#include <tuple>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/ascend_helper.h"
namespace mindspore {
namespace opt {
using OutputInfo =
std::tuple<std::vector<TypeId>, std::vector<std::vector<size_t>>, std::vector<std::string>, std::vector<TypeId>>;
class ConcatOutputsForAllGather : public PatternProcessPass {
public:
explicit ConcatOutputsForAllGather(bool multigraph = true)
@ -33,6 +38,9 @@ class ConcatOutputsForAllGather : public PatternProcessPass {
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
AnfNodePtr InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const OutputInfo &output_info, const std::vector<AnfNodePtr> &new_tuple_getitems,
int64_t rank_size) const;
KernelSelectPtr kernel_select_;
};
} // namespace opt

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -15,7 +15,6 @@
*/
#include "backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.h"
#include <vector>
#include <memory>
#include "backend/optimizer/common/helper.h"
#include "backend/session/anf_runtime_algorithm.h"
@ -34,14 +33,15 @@ const BaseRef InsertPadForNMSWithMask::DefinePattern() const {
return VectorRef({prim::kPrimNMSWithMask, Xs});
}
AnfNodePtr InsertPadToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const TypeId &origin_type,
const std::vector<size_t> &origin_shape) {
AnfNodePtr InsertPadForNMSWithMask::InsertPadToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
const TypeId &origin_type,
const std::vector<size_t> &origin_shape) const {
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> new_pad_inputs;
auto prim = std::make_shared<Primitive>(prim::kPrimPad->name());
new_pad_inputs.push_back(NewValueNode(prim));
new_pad_inputs.push_back(input);
CNodePtr pad = func_graph->NewCNode(new_pad_inputs);
CNodePtr pad = NewCNode(new_pad_inputs, func_graph);
MS_EXCEPTION_IF_NULL(pad);
AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, pad.get());
return pad;
@ -81,7 +81,7 @@ const AnfNodePtr InsertPadForNMSWithMask::Process(const FuncGraphPtr &func_graph
if (kernel_graph == nullptr) {
new_node = std::make_shared<CNode>(*cnode);
} else {
new_node = kernel_graph->NewCNode(cnode);
new_node = NewCNode(cnode, kernel_graph);
}
MS_EXCEPTION_IF_NULL(new_node);
new_node->set_inputs(new_inputs);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,6 +16,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ENHANCER_INSERT_PAD_FOR_NMS_WITH_MASK_H
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ENHANCER_INSERT_PAD_FOR_NMS_WITH_MASK_H
#include <vector>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/common/pass.h"
@ -28,6 +29,10 @@ class InsertPadForNMSWithMask : public PatternProcessPass {
~InsertPadForNMSWithMask() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
AnfNodePtr InsertPadToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const TypeId &origin_type,
const std::vector<size_t> &origin_shape) const;
};
} // namespace opt
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -29,7 +29,7 @@ std::vector<AnfNodePtr> SplitInputsForReduceScatter::InsertSplitForInput(const F
for (size_t i = 0; i < inputs_size; i++) {
std::vector<AnfNodePtr> split_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
split_inputs.push_back(AnfAlgo::GetInputNode(node, i));
auto split = func_graph->NewCNode(split_inputs);
auto split = NewCNode(split_inputs, func_graph);
MS_EXCEPTION_IF_NULL(split);
std::vector<TypeId> dtypes(rank_size, AnfAlgo::GetPrevNodeOutputInferDataType(node, i));
std::vector<std::vector<size_t>> shapes;
@ -68,7 +68,7 @@ AnfNodePtr SplitInputsForReduceScatter::RearrangeInputsForReduceScatter(const Fu
reduce_scatter_inputs.push_back(inputs[idx]);
}
}
auto reduce_scatter = func_graph->NewCNode(reduce_scatter_inputs);
auto reduce_scatter = NewCNode(reduce_scatter_inputs, func_graph);
MS_EXCEPTION_IF_NULL(reduce_scatter);
reduce_scatter->set_abstract(node->abstract());
AnfAlgo::CopyNodeAttrs(node, reduce_scatter);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -162,6 +162,7 @@ CNodePtr DealRefAndSpiltUnSupportedTransdata::MakeDependency(const CNodePtr &get
}
return func_graph->NewCNode(depend_nodes);
}
CNodePtr DealRefAndSpiltUnSupportedTransdata::DealRefForMultipleOutput(
const FuncGraphPtr &func_graph, const CNodePtr &orig_cnode, const std::shared_ptr<kernel::OpInfo> &op_info) const {
MS_EXCEPTION_IF_NULL(func_graph);
@ -172,7 +173,7 @@ CNodePtr DealRefAndSpiltUnSupportedTransdata::DealRefForMultipleOutput(
if (!update_states.empty()) {
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
cnode = kernel_graph->NewCNode(orig_cnode);
cnode = NewCNode(orig_cnode, kernel_graph);
MS_EXCEPTION_IF_NULL(cnode);
cnode->set_inputs(orig_cnode->inputs());
for (auto &update_state : update_states) {

View File

@ -36,9 +36,24 @@
namespace mindspore {
namespace opt {
namespace {
constexpr size_t kType32Len = 4;
constexpr size_t kType64Len = 8;
void UpdateDumpFlag(const AnfNodePtr &node, const std::vector<AnfNodePtr> &orig_nodes) {
for (auto &orig_node : orig_nodes) {
if (!orig_node->isa<CNode>()) {
continue;
}
auto orig_cnode = orig_node->cast<CNodePtr>();
if (AnfAlgo::HasNodeAttr(kAttrDump, orig_cnode)) {
AnfAlgo::CopyNodeAttr(kAttrDump, orig_cnode, node);
break;
}
}
}
} // namespace
std::vector<int64_t> Convert2Int(const std::vector<size_t> &v) {
std::vector<int64_t> result;
(void)std::transform(v.begin(), v.end(), std::back_inserter(result), SizeToInt);
@ -101,6 +116,21 @@ bool UnVisited(const BaseRef &n) {
return false;
}
CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg,
const std::vector<AnfNodePtr> &orig_nodes) {
MS_EXCEPTION_IF_NULL(fg);
auto node = fg->NewCNode(inputs);
UpdateDumpFlag(node, orig_nodes);
return node;
}
CNodePtr NewCNode(const CNodePtr &cnode, const KernelGraphPtr &fg, const std::vector<AnfNodePtr> &orig_nodes) {
MS_EXCEPTION_IF_NULL(fg);
auto node = fg->NewCNode(cnode);
UpdateDumpFlag(node, orig_nodes);
return node;
}
CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, size_t input_size) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
@ -654,7 +684,7 @@ bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
}
namespace {
ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp, PrimitiveVarMap *primitive_vars) {
if (utils::isa<int>(sexp)) {
return NewValueNode(utils::cast<int>(sexp));
}
@ -668,7 +698,16 @@ ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
return NewValueNode(utils::cast<bool>(sexp));
}
if (utils::isa<ValuePtr>(sexp)) {
return NewValueNode(utils::cast<ValuePtr>(sexp));
auto value = utils::cast<ValuePtr>(sexp);
if (utils::isa<PrimitivePtr>(sexp)) {
auto prim = utils::cast<PrimitivePtr>(sexp);
if (primitive_vars->find(prim) != primitive_vars->end()) {
prim = std::make_shared<Primitive>(prim->name());
value = prim;
}
(*primitive_vars)[prim] = std::make_shared<Var>(prim);
}
return NewValueNode(value);
}
return nullptr;
}
@ -831,7 +870,7 @@ AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap
if (utils::isa<AnfNodePtr>(sexp)) {
return utils::cast<AnfNodePtr>(sexp);
}
auto value_node = CreateValueNodeWithSexp(sexp);
auto value_node = CreateValueNodeWithSexp(sexp, primitive_vars);
if (value_node == nullptr) {
MS_LOG(EXCEPTION) << "sexp cannot converted. sexp: " + sexp.ToString();
}

View File

@ -137,6 +137,11 @@ bool UnVisited(const BaseRef &n);
bool Visited(const BaseRef &n);
CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg,
const std::vector<AnfNodePtr> &orig_nodes);
CNodePtr NewCNode(const CNodePtr &cnode, const KernelGraphPtr &fg, const std::vector<AnfNodePtr> &orig_nodes);
// check if the input node is CNode, then check it's input_size, return CNodePtr if check success.
CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, size_t input_size);

View File

@ -31,7 +31,8 @@ PatternProcessPass::PatternProcessPass(const std::string &name, bool multigraph)
: NodePass(name),
multigraph_(multigraph),
pattern_engine_(PatternEngine(std::make_shared<Visitor>())),
primitive_vars_(std::make_shared<PrimitiveVarMap>()) {}
primitive_vars_(std::make_shared<PrimitiveVarMap>()),
equiv_(std::make_shared<Equiv>()) {}
const BaseRef PatternProcessPass::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
@ -50,9 +51,10 @@ AnfNodePtr PatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNode
auto primitive = GetCNodePrimitive(pattern_);
if (IsPrimitiveCNode(node, primitive)) {
auto empty_equiv = std::make_shared<Equiv>();
MS_EXCEPTION_IF_NULL(primitive_vars_);
EquivPtr equiv = pattern_engine_.Match(pattern_, node, *primitive_vars_, empty_equiv);
MS_EXCEPTION_IF_NULL(equiv_);
equiv_->clear();
EquivPtr equiv = pattern_engine_.Match(pattern_, node, *primitive_vars_, equiv_);
if (equiv != nullptr && !equiv->empty()) {
return Process(func_graph, node, equiv);
}
@ -60,21 +62,62 @@ AnfNodePtr PatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNode
return nullptr;
}
std::vector<AnfNodePtr> PatternProcessPass::GetOrigNodes() const {
std::vector<AnfNodePtr> orig_nodes;
for (auto &prim_var : *primitive_vars_) {
if (equiv_->find(prim_var.second) == equiv_->end()) {
continue;
}
auto baseref = (*equiv_)[prim_var.second];
if (!utils::isa<CNode>(baseref)) {
continue;
}
auto node = utils::cast<AnfNodePtr>(baseref);
orig_nodes.push_back(node);
}
return orig_nodes;
}
CNodePtr PatternProcessPass::NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) const {
MS_EXCEPTION_IF_NULL(fg);
auto orig_nodes = GetOrigNodes();
return opt::NewCNode(inputs, fg, orig_nodes);
}
CNodePtr PatternProcessPass::NewCNode(const CNodePtr &cnode, const KernelGraphPtr &fg) const {
MS_EXCEPTION_IF_NULL(fg);
auto orig_nodes = GetOrigNodes();
return opt::NewCNode(cnode, fg, orig_nodes);
}
bool MultipleOutputPatternProcessPass::MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv);
VarPtr fg = std::make_shared<Var>("RootG");
auto empty_equiv = std::make_shared<Equiv>();
MS_EXCEPTION_IF_NULL(child_primitive_vars_);
MS_EXCEPTION_IF_NULL(child_equiv_);
EquivPtr another_equiv =
child_pattern_engine_.Match(SexpToNode(DefineAnotherPattern(), fg, child_primitive_vars_.get(), true), node,
*child_primitive_vars_, empty_equiv);
*child_primitive_vars_, child_equiv_);
if (another_equiv != nullptr && !another_equiv->empty()) {
return IsShareNodes(equiv, another_equiv);
}
return false;
}
std::vector<AnfNodePtr> MultipleOutputPatternProcessPass::GetOrigNodes() const {
std::vector<AnfNodePtr> orig_nodes = PatternProcessPass::GetOrigNodes();
for (auto &prim_var : *child_primitive_vars_) {
auto baseref = (*child_equiv_)[prim_var.second];
if (!utils::isa<CNode>(baseref)) {
continue;
}
auto node = utils::cast<AnfNodePtr>(baseref);
orig_nodes.push_back(node);
}
return orig_nodes;
}
void GraphOptimizer::AddPassManager(const PassManagerPtr &pass_manager) {
if (pass_manager != nullptr) {
pass_managers_.push_back(pass_manager);

View File

@ -41,6 +41,11 @@ class PatternProcessPass : public NodePass {
virtual const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const = 0;
virtual const BaseRef DefinePattern() const;
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) const;
CNodePtr NewCNode(const CNodePtr &cnode, const KernelGraphPtr &fg) const;
protected:
virtual std::vector<AnfNodePtr> GetOrigNodes() const;
private:
void Build();
@ -49,6 +54,7 @@ class PatternProcessPass : public NodePass {
bool multigraph_ = true;
PatternEngine pattern_engine_;
PrimitiveVarMapPtr primitive_vars_;
EquivPtr equiv_;
};
class MultipleOutputPatternProcessPass : public PatternProcessPass {
@ -56,7 +62,8 @@ class MultipleOutputPatternProcessPass : public PatternProcessPass {
explicit MultipleOutputPatternProcessPass(const std::string &name = "", bool multigraph = true)
: PatternProcessPass(name, multigraph),
child_pattern_engine_(PatternEngine(std::make_shared<Visitor>())),
child_primitive_vars_(std::make_shared<PrimitiveVarMap>()) {}
child_primitive_vars_(std::make_shared<PrimitiveVarMap>()),
child_equiv_(std::make_shared<Equiv>()) {}
~MultipleOutputPatternProcessPass() override = default;
virtual BaseRef DefineAnotherPattern() const = 0;
// check two patterns whether share the same nodes or not
@ -64,8 +71,10 @@ class MultipleOutputPatternProcessPass : public PatternProcessPass {
protected:
bool MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const;
std::vector<AnfNodePtr> GetOrigNodes() const override;
PatternEngine child_pattern_engine_;
PrimitiveVarMapPtr child_primitive_vars_;
EquivPtr child_equiv_;
};
class GraphOptimizer {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -25,6 +25,7 @@
#include "runtime/device/kernel_info.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/kernel_build_info.h"
#include "backend/optimizer/common/helper.h"
#include "frontend/parallel/context.h"
namespace mindspore {
@ -318,16 +319,18 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr
if (end_index >= communication_op_info.communication_op_nodes.size()) {
MS_LOG(EXCEPTION) << "end index out of communication_op_nodes size";
}
std::vector<AnfNodePtr> orig_nodes;
for (size_t idx = start_index; idx <= end_index; ++idx) {
auto cnode = communication_op_info.communication_op_nodes[idx];
MS_EXCEPTION_IF_NULL(cnode);
if (idx != start_index) {
AdjustAllReduceInputWithLoad(cnode);
}
fusion_inputs.insert(fusion_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
(void)fusion_inputs.insert(fusion_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
(void)orig_nodes.emplace_back(cnode);
}
CheckInputs(fusion_inputs);
AnfNodePtr fused_node = func_graph->NewCNode(fusion_inputs);
AnfNodePtr fused_node = NewCNode(fusion_inputs, func_graph, orig_nodes);
MS_EXCEPTION_IF_NULL(fused_node);
auto kernel_info = std::make_shared<device::KernelInfo>();
MS_EXCEPTION_IF_NULL(kernel_info);
@ -450,7 +453,7 @@ bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
const float input_grad_size_num = 0.0;
const float input_grad_time_num = 0.0;
// divide candidate fusion groups with same (group,op,fusion) attrs, fusion==0 means not fusion
// divide candidate fusion groups with same (group,op,fusion,dtype) attrs, fusion==0 means not fusion
std::unordered_map<std::string, CommunicationOpInfo> candidate_groups;
std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -58,8 +58,10 @@ AnfNodePtr CreateTensorInput(const KernelGraphPtr &kernel_graph, const AnfNodePt
tensor_input->set_scope(input_node->scope());
return tensor_input;
}
} // namespace
AnfNodePtr ConstInputToTensorInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
AnfNodePtr ConvertConstInputToTensorInput::ConstInputToTensorInput(const FuncGraphPtr &func_graph,
const CNodePtr &cnode) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(cnode);
const std::set<std::string> no_need_to_convert_nodes = {kStackOpName};
@ -89,7 +91,7 @@ AnfNodePtr ConstInputToTensorInput(const FuncGraphPtr &func_graph, const CNodePt
}
if (need_update) {
MS_EXCEPTION_IF_NULL(func_graph);
auto new_cnode = func_graph->NewCNode(new_inputs);
auto new_cnode = NewCNode(new_inputs, func_graph);
MS_EXCEPTION_IF_NULL(new_cnode);
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimDepend)) {
new_cnode->set_abstract(new_inputs[1]->abstract());
@ -105,7 +107,6 @@ AnfNodePtr ConstInputToTensorInput(const FuncGraphPtr &func_graph, const CNodePt
}
return nullptr;
}
} // namespace
const AnfNodePtr ConvertConstInputToTensorInput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -29,6 +29,9 @@ class ConvertConstInputToTensorInput : public PatternProcessPass {
: PatternProcessPass("convert_const_input_to_tensor_input", multigraph) {}
~ConvertConstInputToTensorInput() override = default;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
AnfNodePtr ConstInputToTensorInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const;
};
} // namespace opt
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -81,7 +81,7 @@ const AnfNodePtr ConvertConstScalarToTensor::Process(const FuncGraphPtr &func_gr
if (kernel_graph == nullptr || !input_changed) {
return nullptr;
}
return kernel_graph->NewCNode(cnode);
return NewCNode(cnode, kernel_graph);
}
} // namespace opt
} // namespace mindspore

View File

@ -88,7 +88,7 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func
if (kernel_graph == nullptr || !cnode_input_changed) {
return nullptr;
}
return kernel_graph->NewCNode(cnode);
return NewCNode(cnode, kernel_graph);
}
} // namespace opt
} // namespace mindspore

View File

@ -26,7 +26,6 @@
namespace mindspore {
namespace opt {
using KernelWithIndex = std::pair<CNodePtr, size_t>;
namespace {
CNodePtr GetRealPrevCNode(const AnfNodePtr &node, size_t index, std::vector<KernelWithIndex> *pass_vector) {
MS_EXCEPTION_IF_NULL(pass_vector);
@ -78,9 +77,11 @@ bool TransDataOpEliminateCondition(const CNodePtr &node1, const CNodePtr &node2)
AnfAlgo::GetOutputFormat(node1, 0) == AnfAlgo::GetInputFormat(node2, 0) &&
kernel::IsSameShape(AnfAlgo::GetInputDeviceShape(node2, 0), AnfAlgo::GetOutputDeviceShape(node1, 0));
}
} // namespace
const AnfNodePtr ProcessMatchedNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const CNodePtr &prev_cnode,
std::vector<KernelWithIndex> *pass_vector) {
const AnfNodePtr EliminateRedundantOp::ProcessMatchedNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const CNodePtr &prev_cnode,
std::vector<KernelWithIndex> *pass_vector) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(pass_vector);
FuncGraphManagerPtr manager = func_graph->manager();
@ -113,7 +114,7 @@ const AnfNodePtr ProcessMatchedNodes(const FuncGraphPtr &func_graph, const CNode
MS_LOG(ERROR) << "pass_size should >= 2";
}
for (size_t idx = pass_size - kOffset; idx > 0; --idx) {
auto new_node = func_graph->NewCNode((*pass_vector)[idx].first->inputs());
auto new_node = NewCNode((*pass_vector)[idx].first->inputs(), func_graph);
if (idx == pass_size - kOffset) {
new_node->set_input((*pass_vector)[idx].second,
(*pass_vector)[idx + 1].first->input((*pass_vector)[idx + 1].second));
@ -125,7 +126,6 @@ const AnfNodePtr ProcessMatchedNodes(const FuncGraphPtr &func_graph, const CNode
return (*pass_vector)[1].first;
}
}
} // namespace
void EliminateRedundantOp::Init() {
(void)redundant_process_map_.emplace(std::pair<std::string, RedundantOpPair>(

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -29,6 +29,7 @@ namespace mindspore {
namespace opt {
using ConditionFunc = std::function<bool(const CNodePtr &node1, const CNodePtr &node2)>;
using RedundantOpPair = std::pair<std::string, ConditionFunc>;
using KernelWithIndex = std::pair<CNodePtr, size_t>;
class EliminateRedundantOp : public PatternProcessPass {
public:
@ -41,6 +42,8 @@ class EliminateRedundantOp : public PatternProcessPass {
private:
void Init();
const AnfNodePtr DoEliminate(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const;
const AnfNodePtr ProcessMatchedNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const CNodePtr &prev_cnode, std::vector<KernelWithIndex> *pass_vector) const;
std::unordered_map<std::string, RedundantOpPair> redundant_process_map_;
};
} // namespace opt

View File

@ -28,46 +28,8 @@
namespace mindspore {
namespace opt {
const int axis_input_index = 2;
namespace {
AnfNodePtr NewRankOp(const AnfNodePtr &cnode, const KernelGraphPtr &kernel_graph) {
std::vector<AnfNodePtr> rank_inputs;
auto prim = std::make_shared<Primitive>(prim::kPrimRank->name());
rank_inputs.push_back(NewValueNode(prim));
auto prev_node = AnfAlgo::GetPrevNodeOutput(cnode, 1);
rank_inputs.push_back(prev_node.first);
auto rank_op = kernel_graph->NewCNode(rank_inputs);
MS_EXCEPTION_IF_NULL(rank_op);
rank_op->set_abstract(prev_node.first->abstract());
return rank_op;
}
AnfNodePtr NewRangeOp(const AnfNodePtr &rank_op, const KernelGraphPtr &kernel_graph) {
std::vector<AnfNodePtr> range_inputs;
auto prim = std::make_shared<Primitive>(prim::kPrimRange->name());
range_inputs.push_back(NewValueNode(prim));
// "start"
auto start_ = NewValueNode(SizeToLong(0));
MS_EXCEPTION_IF_NULL(start_);
auto imm_start = std::make_shared<Int64Imm>(SizeToLong(0));
start_->set_abstract(std::make_shared<abstract::AbstractScalar>(imm_start));
range_inputs.push_back(start_);
// "limit"
range_inputs.push_back(rank_op);
// "delta"
auto delta_ = NewValueNode(SizeToLong(1));
MS_EXCEPTION_IF_NULL(delta_);
auto imm_delta = std::make_shared<Int64Imm>(SizeToLong(1));
delta_->set_abstract(std::make_shared<abstract::AbstractScalar>(imm_delta));
range_inputs.push_back(delta_);
// new range op
auto range_op = kernel_graph->NewCNode(range_inputs);
MS_EXCEPTION_IF_NULL(range_op);
range_op->set_abstract(rank_op->abstract());
return range_op;
}
const int axis_input_index = 2;
bool IsNeedComputeRank(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
@ -95,8 +57,48 @@ bool IsNeedComputeRank(const CNodePtr &cnode) {
}
return false;
}
} // namespace
AnfNodePtr InsertAssistNode(const CNodePtr &cnode, const KernelGraphPtr &kernel_graph) {
AnfNodePtr ReduceSumOptimizer::NewRankOp(const AnfNodePtr &cnode, const KernelGraphPtr &kernel_graph) const {
std::vector<AnfNodePtr> rank_inputs;
auto prim = std::make_shared<Primitive>(prim::kPrimRank->name());
rank_inputs.push_back(NewValueNode(prim));
auto prev_node = AnfAlgo::GetPrevNodeOutput(cnode, 1);
rank_inputs.push_back(prev_node.first);
auto rank_op = NewCNode(rank_inputs, kernel_graph);
MS_EXCEPTION_IF_NULL(rank_op);
rank_op->set_abstract(prev_node.first->abstract());
return rank_op;
}
AnfNodePtr ReduceSumOptimizer::NewRangeOp(const AnfNodePtr &rank_op, const KernelGraphPtr &kernel_graph) const {
std::vector<AnfNodePtr> range_inputs;
auto prim = std::make_shared<Primitive>(prim::kPrimRange->name());
range_inputs.push_back(NewValueNode(prim));
// "start"
auto start_ = NewValueNode(SizeToLong(0));
MS_EXCEPTION_IF_NULL(start_);
auto imm_start = std::make_shared<Int64Imm>(SizeToLong(0));
start_->set_abstract(std::make_shared<abstract::AbstractScalar>(imm_start));
range_inputs.push_back(start_);
// "limit"
range_inputs.push_back(rank_op);
// "delta"
auto delta_ = NewValueNode(SizeToLong(1));
MS_EXCEPTION_IF_NULL(delta_);
auto imm_delta = std::make_shared<Int64Imm>(SizeToLong(1));
delta_->set_abstract(std::make_shared<abstract::AbstractScalar>(imm_delta));
range_inputs.push_back(delta_);
// new range op
auto range_op = NewCNode(range_inputs, kernel_graph);
MS_EXCEPTION_IF_NULL(range_op);
range_op->set_abstract(rank_op->abstract());
return range_op;
}
AnfNodePtr ReduceSumOptimizer::InsertAssistNode(const CNodePtr &cnode, const KernelGraphPtr &kernel_graph) const {
// the input dim is unknown, need rank + range, can not supported now; ;
MS_LOG(EXCEPTION)
<< "Can not support the case that input is dim unknown and axis is empty or axis contain value less 0. node: "
@ -108,7 +110,7 @@ AnfNodePtr InsertAssistNode(const CNodePtr &cnode, const KernelGraphPtr &kernel_
std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
new_inputs.push_back(cnode->input(1));
new_inputs.push_back(range_op);
auto new_node = kernel_graph->NewCNode(cnode);
auto new_node = NewCNode(cnode, kernel_graph);
MS_EXCEPTION_IF_NULL(new_node);
new_node->set_inputs(new_inputs);
return new_node;
@ -121,7 +123,7 @@ AnfNodePtr InsertAssistNode(const CNodePtr &cnode, const KernelGraphPtr &kernel_
// 2: the value of axis_input contain the value less 0,
// the new tensor of the new value node should be "shape.size() + the_old_value_less_0",
// the shape is the first input'shape of ReduceSum;
AnfNodePtr NewAssistValueNode(const CNodePtr &cnode, const KernelGraphPtr &kernel_graph) {
AnfNodePtr ReduceSumOptimizer::NewAssistValueNode(const CNodePtr &cnode, const KernelGraphPtr &kernel_graph) const {
// axis is a tuple ,maybe empty or contain a value less 0;
auto axis_input = cnode->input(axis_input_index);
if (IsValueNode<ValueTuple>(axis_input)) {
@ -159,7 +161,7 @@ AnfNodePtr NewAssistValueNode(const CNodePtr &cnode, const KernelGraphPtr &kerne
assist_node->set_abstract(std::make_shared<abstract::AbstractTensor>(kInt64, axes_value));
auto assist_value_node = kernel_graph->NewValueNode(assist_node);
new_inputs.push_back(assist_value_node);
auto new_node = kernel_graph->NewCNode(cnode);
auto new_node = NewCNode(cnode, kernel_graph);
MS_EXCEPTION_IF_NULL(new_node);
new_node->set_inputs(new_inputs);
return new_node;
@ -167,7 +169,6 @@ AnfNodePtr NewAssistValueNode(const CNodePtr &cnode, const KernelGraphPtr &kerne
}
return nullptr;
}
} // namespace
const AnfNodePtr ReduceSumOptimizer::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {

View File

@ -26,6 +26,12 @@ class ReduceSumOptimizer : public PatternProcessPass {
~ReduceSumOptimizer() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
AnfNodePtr NewRankOp(const AnfNodePtr &cnode, const KernelGraphPtr &kernel_graph) const;
AnfNodePtr NewRangeOp(const AnfNodePtr &rank_op, const KernelGraphPtr &kernel_graph) const;
AnfNodePtr InsertAssistNode(const CNodePtr &cnode, const KernelGraphPtr &kernel_graph) const;
AnfNodePtr NewAssistValueNode(const CNodePtr &cnode, const KernelGraphPtr &kernel_graph) const;
};
} // namespace opt
} // namespace mindspore

View File

@ -2001,6 +2001,15 @@ bool AnfRuntimeAlgorithm::GetBooleanAttr(const AnfNodePtr &node, const std::stri
return AnfAlgo::GetNodeAttr<bool>(node, attr);
}
std::optional<string> AnfRuntimeAlgorithm::GetDumpFlag(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr || !AnfAlgo::HasNodeAttr(kAttrDump, cnode)) {
return std::optional<string>{};
}
return std::optional<string>{AnfAlgo::GetNodeAttr<string>(node, kAttrDump)};
}
bool AnfRuntimeAlgorithm::HasDynamicShapeFlag(const PrimitivePtr &prim) {
auto get_bool_attr = [](const PrimitivePtr &primitive, const std::string &attr_name) -> bool {
MS_EXCEPTION_IF_NULL(primitive);

View File

@ -25,6 +25,7 @@
#include <memory>
#include <unordered_set>
#include <map>
#include <optional>
#include "ir/anf.h"
#include "ir/dtype.h"
#include "base/base.h"
@ -288,6 +289,7 @@ class AnfRuntimeAlgorithm {
static bool IsCondControlKernel(const CNodePtr &node);
static bool IsIndependentNode(const CNodePtr &node);
static bool GetBooleanAttr(const AnfNodePtr &node, const std::string &attr);
static std::optional<string> GetDumpFlag(const AnfNodePtr &node);
static void GetRealDynamicShape(const std::vector<size_t> &shape, NotNull<std::vector<int64_t> *> dynamic_shape);
static std::vector<int64_t> GetInputMaxShape(const AnfNodePtr &anf_node, size_t index);
static std::vector<int64_t> GetInputMinShape(const AnfNodePtr &anf_node, size_t index);

View File

@ -483,6 +483,7 @@ constexpr auto kAttrProfilingIterEnd = "PROFILING_ITER_END";
constexpr auto kAttrHiddenSize = "hidden_size";
constexpr auto kAttrInputSize = "input_size";
constexpr auto kAttrDstType = "dst_type";
constexpr auto kAttrDump = "dump";
constexpr auto kAttrSkipNopOpAddr = "skip_nop_op_addr";
constexpr auto kAttrFuncType = "func_type";

View File

@ -1218,5 +1218,16 @@ int GetDataTypeFromAnfNode(const AnfNodePtr &anf_node, TypeId *type_id) {
*type_id = type_ptr->type_id();
return RET_OK;
}
// not implement for lite, just for api compatible
CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg,
const std::vector<AnfNodePtr> &orig_nodes) {
return fg->NewCNode(inputs);
}
// not implement for lite, just for api compatible
CNodePtr NewCNode(const CNodePtr &cnode, const KernelGraphPtr &fg, const std::vector<AnfNodePtr> &orig_nodes) {
return nullptr;
}
} // namespace opt
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,14 +16,18 @@
#include <iostream>
#include <sstream>
#include <vector>
#include <memory>
#include <algorithm>
#include "common/common_test.h"
#include "backend/optimizer/common/pattern_engine.h"
#include "backend/optimizer/common/visit.h"
#include "backend/optimizer/common/helper.h"
#include "base/base_ref.h"
#include "base/core_ops.h"
#include "ir/anf.h"
#include "utils/utils.h"
namespace mindspore {
using PatternListType = std::initializer_list<BaseRef>;
@ -32,15 +36,16 @@ bool Equal(const BaseRef &a, const BaseRef &b) { return a == b; }
class TestMatchEngine : public UT::Common {
public:
TestMatchEngine()
: TU(std::make_shared<Visitor>()) {
TestMatchEngine() : TU(std::make_shared<Visitor>()) {
equiv_null = std::make_shared<Equiv>();
fg = std::make_shared<FuncGraph>();
};
public:
PatternEngine TU;
EquivPtr equiv_null;
PrimitiveVarMap primitive_vars_null;
FuncGraphPtr fg;
};
TEST_F(TestMatchEngine, Var) {
@ -215,4 +220,79 @@ TEST_F(TestMatchEngine, Match_CondVar) {
equiv_null);
ASSERT_EQ(d, nullptr);
}
/// Feature: Backend support dump flag
/// Description: PatternEngine match var with primitive
/// Expectation: Get correct Equiv map
TEST_F(TestMatchEngine, Match_PrimVar) {
VarPtr mul1 = std::make_shared<Var>(std::make_shared<Primitive>(kMulOpName));
VarPtr mul2 = std::make_shared<Var>(std::make_shared<Primitive>(kMulOpName));
VarPtr v1 = std::make_shared<Var>();
VarPtr sv2 = std::make_shared<SeqVar>();
auto pattern_ref = VectorRef({mul1, v1, VectorRef({mul2, sv2})});
PrimitiveVarMapPtr primitive_vars = std::make_shared<PrimitiveVarMap>();
auto pattern_node = opt::SexpToNode(pattern_ref, fg, primitive_vars.get(), true);
ASSERT_EQ(primitive_vars->size(), std::size_t(2));
auto anode1 = std::make_shared<AnfNode>(fg);
auto anode2 = std::make_shared<AnfNode>(fg);
auto anode3 = std::make_shared<AnfNode>(fg);
AnfNodePtr mul2_cnode = std::make_shared<CNode>(
std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kMulOpName)), anode2, anode3}, fg);
AnfNodePtr mul1_cnode = std::make_shared<CNode>(
std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kMulOpName)), anode1, mul2_cnode}, fg);
EquivPtr d;
equiv_null->clear();
d = TU.Match(pattern_node, mul1_cnode, *primitive_vars, equiv_null);
ASSERT_EQ(d->size(), std::size_t(4));
ASSERT_EQ((*d)[mul2], mul2_cnode);
ASSERT_EQ((*d)[mul1], mul1_cnode);
AnfNodePtr sub_cnode = std::make_shared<CNode>(
std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kSubOpName)), anode1, mul2_cnode}, fg);
equiv_null->clear();
d = TU.Match(pattern_node, sub_cnode, *primitive_vars, equiv_null);
ASSERT_EQ(d, nullptr);
}
/// Feature: Backend support dump flag
/// Description: PatternEngine match primitive
/// Expectation: Get correct Equiv map
TEST_F(TestMatchEngine, Match_Prim) {
VarPtr v1 = std::make_shared<Var>();
VarPtr sv2 = std::make_shared<SeqVar>();
auto pattern_ref = VectorRef({prim::kPrimMul, v1, VectorRef({prim::kPrimMul, sv2})});
PrimitiveVarMapPtr primitive_vars = std::make_shared<PrimitiveVarMap>();
auto pattern_node = opt::SexpToNode(pattern_ref, fg, primitive_vars.get(), true);
ASSERT_EQ(primitive_vars->size(), std::size_t(2));
auto anode1 = std::make_shared<AnfNode>(fg);
auto anode2 = std::make_shared<AnfNode>(fg);
auto anode3 = std::make_shared<AnfNode>(fg);
AnfNodePtr mul2_cnode = std::make_shared<CNode>(
std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kMulOpName)), anode2, anode3}, fg);
AnfNodePtr mul1_cnode = std::make_shared<CNode>(
std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kMulOpName)), anode1, mul2_cnode}, fg);
EquivPtr d;
equiv_null->clear();
d = TU.Match(pattern_node, mul1_cnode, *primitive_vars, equiv_null);
ASSERT_EQ(d->size(), std::size_t(4));
for (auto &prim_var : *primitive_vars) {
if (prim_var.first == prim::kPrimMul) {
ASSERT_EQ((*d)[prim_var.second], mul1_cnode);
} else {
ASSERT_EQ((*d)[prim_var.second], mul2_cnode);
}
}
AnfNodePtr sub_cnode = std::make_shared<CNode>(
std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kSubOpName)), anode1, mul2_cnode}, fg);
equiv_null->clear();
d = TU.Match(pattern_node, sub_cnode, *primitive_vars, equiv_null);
ASSERT_EQ(d, nullptr);
}
} // namespace mindspore

View File

@ -0,0 +1,84 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <sstream>
#include <vector>
#include <memory>
#include <algorithm>
#include "common/common_test.h"
#define private public
#define protected public
#include "backend/optimizer/common/optimizer.h"
#undef private
#undef protected
#include "base/core_ops.h"
#include "ir/anf.h"
#include "ir/value.h"
#include "utils/utils.h"
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore {
namespace opt {
class TestPass : public PatternProcessPass {
public:
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const { return nullptr; };
};
class TestPatternProcessPass : public UT::Common {
public:
TestPatternProcessPass() : TU() { fg = std::make_shared<FuncGraph>(); };
public:
TestPass TU;
FuncGraphPtr fg;
};
/// Feature: Backend support dump flag
/// Description: Get orig nodes according to primitive_vars_ and equiv_
/// Expectation: Get correct orig nodes
TEST_F(TestPatternProcessPass, test_GetOrigNodes) {
TU.primitive_vars_->clear();
TU.equiv_->clear();
VarPtr mul1 = std::make_shared<Var>(std::make_shared<Primitive>(kMulOpName));
VarPtr v1 = std::make_shared<Var>();
VarPtr v2 = std::make_shared<Var>();
(*TU.primitive_vars_)[mul1->primitive()] = mul1;
auto mul1_node = std::make_shared<CNode>(std::vector<AnfNodePtr>{NewValueNode(prim::kPrimMul)}, fg);
auto anode1 = std::make_shared<AnfNode>(fg);
auto anode2 = std::make_shared<AnfNode>(fg);
(*TU.equiv_)[mul1] = mul1_node;
(*TU.equiv_)[v1] = anode1;
(*TU.equiv_)[v2] = anode2;
auto orig_nodes = TU.GetOrigNodes();
ASSERT_EQ(orig_nodes.size(), std::size_t(1));
ASSERT_EQ(orig_nodes[0], mul1_node);
VarPtr mul2 = std::make_shared<Var>(std::make_shared<Primitive>(kMulOpName));
(*TU.primitive_vars_)[mul2->primitive()] = mul2;
orig_nodes = TU.GetOrigNodes();
ASSERT_EQ(orig_nodes.size(), std::size_t(1));
auto mul2_node = std::make_shared<CNode>(std::vector<AnfNodePtr>{NewValueNode(prim::kPrimMul)}, fg);
(*TU.equiv_)[mul2] = mul2_node;
orig_nodes = TU.GetOrigNodes();
ASSERT_EQ(orig_nodes.size(), std::size_t(2));
}
} // namespace opt
} // namespace mindspore