!17614 pass optimizer
From: @zhupuxu Reviewed-by: @zhoufeng54 Signed-off-by:
This commit is contained in:
commit
86cb781ce0
|
@ -32,9 +32,7 @@ namespace opt {
|
|||
PatternProcessPass::PatternProcessPass(const std::string &name, bool multigraph)
|
||||
: NodePass(name),
|
||||
multigraph_(multigraph),
|
||||
pattern_engine_(PatternEngine(std::make_shared<DefaultVisitor>(),
|
||||
std::function<bool(const BaseRef &, const BaseRef &)>(AnfEqual),
|
||||
std::function<bool(const BaseRef &, const BaseRef &)>(CNodeTypeEqual))),
|
||||
pattern_engine_(PatternEngine(std::make_shared<Visitor>())),
|
||||
primitive_vars_(std::make_shared<PrimitiveVarMap>()) {}
|
||||
|
||||
const BaseRef PatternProcessPass::DefinePattern() const {
|
||||
|
@ -53,11 +51,14 @@ AnfNodePtr PatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNode
|
|||
Build();
|
||||
}
|
||||
|
||||
auto empty_equiv = std::make_shared<Equiv>();
|
||||
MS_EXCEPTION_IF_NULL(primitive_vars_);
|
||||
EquivPtr equiv = pattern_engine_.Match(pattern_, node, *primitive_vars_, empty_equiv);
|
||||
if (equiv != nullptr && !equiv->empty()) {
|
||||
return Process(func_graph, node, equiv);
|
||||
auto primitive = GetCNodePrimitive(pattern_);
|
||||
if (IsPrimitiveCNode(node, primitive)) {
|
||||
auto empty_equiv = std::make_shared<Equiv>();
|
||||
MS_EXCEPTION_IF_NULL(primitive_vars_);
|
||||
EquivPtr equiv = pattern_engine_.Match(pattern_, node, *primitive_vars_, empty_equiv);
|
||||
if (equiv != nullptr && !equiv->empty()) {
|
||||
return Process(func_graph, node, equiv);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -55,9 +55,7 @@ class MultipleOutputPatternProcessPass : public PatternProcessPass {
|
|||
public:
|
||||
explicit MultipleOutputPatternProcessPass(const std::string &name = "", bool multigraph = true)
|
||||
: PatternProcessPass(name, multigraph),
|
||||
child_pattern_engine_(PatternEngine(std::make_shared<DefaultVisitor>(),
|
||||
std::function<bool(const BaseRef &, const BaseRef &)>(AnfEqual),
|
||||
std::function<bool(const BaseRef &, const BaseRef &)>(CNodeTypeEqual))),
|
||||
child_pattern_engine_(PatternEngine(std::make_shared<Visitor>())),
|
||||
child_primitive_vars_(std::make_shared<PrimitiveVarMap>()) {}
|
||||
~MultipleOutputPatternProcessPass() override = default;
|
||||
virtual BaseRef DefineAnotherPattern() const = 0;
|
||||
|
|
|
@ -24,6 +24,9 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class CacheManager;
|
||||
using CacheManagerPtr = std::shared_ptr<CacheManager>;
|
||||
|
||||
// @brief ANF Graph level optimization base pass
|
||||
class Pass {
|
||||
public:
|
||||
|
@ -31,9 +34,12 @@ class Pass {
|
|||
virtual ~Pass() = default;
|
||||
virtual bool Run(const FuncGraphPtr &func_graph) = 0;
|
||||
virtual std::string name() const { return name_; }
|
||||
void SetCacheManager(const CacheManagerPtr &cm) { cache_manager_ = cm; }
|
||||
const CacheManagerPtr &GetCacheManager() const { return cache_manager_; }
|
||||
|
||||
private:
|
||||
const std::string name_;
|
||||
CacheManagerPtr cache_manager_;
|
||||
};
|
||||
using PassPtr = std::shared_ptr<Pass>;
|
||||
} // namespace opt
|
||||
|
|
|
@ -23,9 +23,73 @@
|
|||
#include "ir/manager.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
||||
void CacheManager::Update(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto type_iter = type_map_.find(node);
|
||||
auto shape_iter = shape_map_.find(node);
|
||||
if (type_iter != type_map_.end()) {
|
||||
type_map_.erase(type_iter);
|
||||
}
|
||||
if (shape_iter != shape_map_.end()) {
|
||||
shape_map_.erase(shape_iter);
|
||||
}
|
||||
}
|
||||
|
||||
TypeId CacheManager::GetOutputType(const AnfNodePtr &node, size_t index) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto iter = type_map_.find(node);
|
||||
if (iter != type_map_.end()) {
|
||||
auto types = iter->second;
|
||||
auto type_iter = types.find(index);
|
||||
if (type_iter != types.end()) {
|
||||
return type_iter->second;
|
||||
}
|
||||
return kTypeUnknown;
|
||||
}
|
||||
auto output_nums = AnfAlgo::GetOutputTensorNum(node);
|
||||
std::map<size_t, TypeId> index_to_types;
|
||||
TypeId result = kTypeUnknown;
|
||||
for (size_t i = 0; i < output_nums; i++) {
|
||||
auto output_type = AnfAlgo::GetOutputInferDataType(node, i);
|
||||
index_to_types.emplace(i, output_type);
|
||||
if (index == i) {
|
||||
result = output_type;
|
||||
}
|
||||
}
|
||||
type_map_.emplace(node, index_to_types);
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<size_t> CacheManager::GetOutputShape(const AnfNodePtr &node, size_t index) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto iter = shape_map_.find(node);
|
||||
if (iter != shape_map_.end()) {
|
||||
auto shapes = iter->second;
|
||||
auto shape_iter = shapes.find(index);
|
||||
if (shape_iter != shapes.end()) {
|
||||
return shape_iter->second;
|
||||
}
|
||||
return {};
|
||||
}
|
||||
auto output_nums = AnfAlgo::GetOutputTensorNum(node);
|
||||
std::map<size_t, std::vector<size_t>> index_to_shapes;
|
||||
std::vector<size_t> result = {};
|
||||
for (size_t i = 0; i < output_nums; i++) {
|
||||
auto output_shape = AnfAlgo::GetOutputInferShape(node, i);
|
||||
index_to_shapes.emplace(i, output_shape);
|
||||
if (index == i) {
|
||||
result = output_shape;
|
||||
}
|
||||
}
|
||||
shape_map_.emplace(node, index_to_shapes);
|
||||
return result;
|
||||
}
|
||||
|
||||
const std::vector<PassPtr> &PassManager::Passes() const { return passes_; }
|
||||
|
||||
void PassManager::AddPass(const PassPtr &pass) {
|
||||
|
@ -84,6 +148,7 @@ bool PassManager::Run(const FuncGraphPtr &func_graph, const std::vector<PassPtr>
|
|||
size_t num = 0;
|
||||
for (const auto &pass : passes) {
|
||||
if (pass != nullptr) {
|
||||
pass->SetCacheManager(cache_manager_);
|
||||
changed = RunPass(func_graph, num, pass) || changed;
|
||||
DumpPassIR(func_graph, GetPassFullname(num, pass));
|
||||
num++;
|
||||
|
|
|
@ -20,17 +20,32 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
#include "backend/optimizer/common/node_pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class CacheManager {
|
||||
public:
|
||||
CacheManager() {}
|
||||
~CacheManager() = default;
|
||||
void Update(const AnfNodePtr &node);
|
||||
TypeId GetOutputType(const AnfNodePtr &node, size_t index);
|
||||
std::vector<size_t> GetOutputShape(const AnfNodePtr &node, size_t index);
|
||||
|
||||
private:
|
||||
std::map<AnfNodePtr, std::map<size_t, TypeId>> type_map_;
|
||||
std::map<AnfNodePtr, std::map<size_t, std::vector<size_t>>> shape_map_;
|
||||
};
|
||||
using CacheManagerPtr = std::shared_ptr<CacheManager>;
|
||||
|
||||
// @brief For optimization passes management
|
||||
class PassManager {
|
||||
public:
|
||||
explicit PassManager(const std::string &name = "pm", bool run_only_once = true)
|
||||
: name_(name), passes_{}, run_only_once_(run_only_once) {}
|
||||
: name_(name), passes_{}, run_only_once_(run_only_once), cache_manager_(std::make_shared<CacheManager>()) {}
|
||||
virtual ~PassManager() = default;
|
||||
// Get all the passes added by AddPass
|
||||
const std::vector<PassPtr> &Passes() const;
|
||||
|
@ -57,6 +72,7 @@ class PassManager {
|
|||
const std::string name_;
|
||||
std::vector<PassPtr> passes_;
|
||||
bool run_only_once_;
|
||||
CacheManagerPtr cache_manager_;
|
||||
};
|
||||
using PassManagerPtr = std::shared_ptr<PassManager>;
|
||||
} // namespace opt
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "ir/anf.h"
|
||||
#include "utils/convert_utils_base.h"
|
||||
#include "utils/overload.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
static int GetNextTag() {
|
||||
|
@ -156,33 +157,13 @@ bool PatternEngine::ToVector(const VectorRef &pattern_ref, const VectorRef &expr
|
|||
bool PatternEngine::ToVector(const BaseRef &pattern_ref, const BaseRef &expr_ref, VectorRef *const values_pattern,
|
||||
VectorRef *const values_expr) const {
|
||||
MS_EXCEPTION_IF_NULL(values_expr);
|
||||
// visitor to visite the list
|
||||
auto appender_pattern = [](VectorRef &values) {
|
||||
std::function<BaseRef(const BaseRef &)> fn = [&](const BaseRef &u) {
|
||||
values.push_back(GetVar(u));
|
||||
return u;
|
||||
};
|
||||
return fn;
|
||||
};
|
||||
|
||||
visitor_->SetFn(appender_pattern(*values_pattern));
|
||||
MS_LOG(DEBUG) << "visit pattern_ref";
|
||||
bool success = visitor_->Visit(pattern_ref, nullptr);
|
||||
bool success = visitor_->Visit(pattern_ref, values_pattern, nullptr);
|
||||
if (!success) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto appender_expr = [](VectorRef &values) {
|
||||
std::function<BaseRef(const BaseRef &)> fn = [&](const BaseRef &u) {
|
||||
values.push_back(u);
|
||||
return u;
|
||||
};
|
||||
return fn;
|
||||
};
|
||||
|
||||
visitor_->SetFn(appender_expr(*values_expr));
|
||||
MS_LOG(DEBUG) << "visit expr_ref";
|
||||
return visitor_->Visit(expr_ref, nullptr);
|
||||
return visitor_->Visit(expr_ref, values_expr, nullptr);
|
||||
}
|
||||
|
||||
static int GetSVarStartIndex(const VectorRef &values) {
|
||||
|
@ -292,7 +273,7 @@ EquivPtr PatternEngine::Match(const BaseRef &pattern, const BaseRef &expr, const
|
|||
}
|
||||
|
||||
// 2. check equal
|
||||
if (eq_(pattern_ref, expr_ref)) {
|
||||
if (PatternEngine::AnfNodeEqual(pattern_ref, expr_ref)) {
|
||||
return equiv;
|
||||
}
|
||||
|
||||
|
@ -304,7 +285,7 @@ EquivPtr PatternEngine::Match(const BaseRef &pattern, const BaseRef &expr, const
|
|||
|
||||
// 4. here the type can be std:vector, std:list,
|
||||
// or cnode.
|
||||
if (!type_eq_(pattern_ref, expr_ref)) {
|
||||
if (!PatternEngine::CNodeTypeEqual(pattern_ref, expr_ref)) {
|
||||
MS_LOG(DEBUG) << "Type mismatch";
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -324,34 +305,62 @@ EquivPtr PatternEngine::Match(const BaseRef &pattern, const BaseRef &expr, const
|
|||
return equiv;
|
||||
}
|
||||
|
||||
BaseRef PatternEngine::Replace(const BaseRef &pattern, const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
MS_LOG(DEBUG) << "-----[in Replace]";
|
||||
BaseRef ref = GetVar(pattern);
|
||||
BaseRef out;
|
||||
bool is_match = false;
|
||||
bool PatternEngine::AnfNodeEqual(const BaseRef &a, const BaseRef &b) {
|
||||
if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
|
||||
auto a_node = utils::cast<AnfNodePtr>(a);
|
||||
auto b_node = utils::cast<AnfNodePtr>(b);
|
||||
MS_EXCEPTION_IF_NULL(a_node);
|
||||
MS_EXCEPTION_IF_NULL(b_node);
|
||||
if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
|
||||
auto a_value_node = a_node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(a_value_node);
|
||||
auto a_value = a_value_node->value();
|
||||
MS_EXCEPTION_IF_NULL(a_value);
|
||||
auto a_prim = a_value->cast<PrimitivePtr>();
|
||||
MS_EXCEPTION_IF_NULL(a_prim);
|
||||
|
||||
// w is var
|
||||
if (utils::isa<VarPtr>(ref)) {
|
||||
const VarPtr &var = utils::cast<VarPtr>(ref);
|
||||
auto iter = equiv->find(var);
|
||||
if (iter != equiv->end()) {
|
||||
out = iter->second;
|
||||
is_match = true;
|
||||
auto b_value_node = b_node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(b_value_node);
|
||||
auto b_value = b_value_node->value();
|
||||
MS_EXCEPTION_IF_NULL(b_value);
|
||||
auto b_prim = b_value->cast<PrimitivePtr>();
|
||||
MS_EXCEPTION_IF_NULL(b_prim);
|
||||
|
||||
return a_prim->name() == b_prim->name();
|
||||
} else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
|
||||
auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
|
||||
if (a_value_node_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "cast value node ptr fail";
|
||||
}
|
||||
auto a_value_ptr = a_value_node_ptr->value();
|
||||
if (a_value_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "value ptr is nullptr";
|
||||
}
|
||||
|
||||
auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
|
||||
if (b_value_node_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "cast value node ptr fail";
|
||||
}
|
||||
auto b_value_ptr = b_value_node_ptr->value();
|
||||
if (b_value_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "value ptr is nullptr";
|
||||
}
|
||||
|
||||
return (*a_value_ptr) == (*b_value_ptr);
|
||||
}
|
||||
MS_LOG(DEBUG) << "check AnfNodePtr equal";
|
||||
}
|
||||
if (is_match) {
|
||||
return out;
|
||||
if (utils::isa<FuncGraphPtr>(a) && utils::isa<FuncGraphPtr>(b)) {
|
||||
MS_LOG(DEBUG) << "check GraphPtr equal";
|
||||
}
|
||||
return a == b;
|
||||
}
|
||||
|
||||
// visitor to visit the list
|
||||
std::function<BaseRef(BaseRef)> fn = [&, this, equiv](const BaseRef &u) { return Replace(u, equiv); };
|
||||
|
||||
visitor_->SetFn(fn);
|
||||
BaseRef visit_out;
|
||||
if (!visitor_->Visit(pattern, &visit_out)) {
|
||||
return pattern;
|
||||
bool PatternEngine::CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
|
||||
// To matchCNode and Kernel's type
|
||||
if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
|
||||
return true;
|
||||
}
|
||||
return visit_out;
|
||||
return a.type() == b.type();
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -163,16 +163,12 @@ inline bool DefaultTypeEq(const BaseRef &x, const BaseRef &y) { return x.type()
|
|||
|
||||
class PatternEngine {
|
||||
public:
|
||||
PatternEngine(const std::shared_ptr<Visitor> &visitor,
|
||||
const std::function<bool(const BaseRef &, const BaseRef &)> &eq,
|
||||
const std::function<bool(const BaseRef &, const BaseRef &)> &type_eq = DefaultTypeEq)
|
||||
: visitor_(visitor), eq_(eq), type_eq_(type_eq) {}
|
||||
explicit PatternEngine(const std::shared_ptr<Visitor> &visitor) : visitor_(visitor) {}
|
||||
~PatternEngine() = default;
|
||||
|
||||
EquivPtr Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars,
|
||||
EquivPtr equiv) const;
|
||||
// Replace pattern with equivalent
|
||||
BaseRef Replace(const BaseRef &pattern, const EquivPtr &equiv) const;
|
||||
|
||||
private:
|
||||
EquivPtr AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr,
|
||||
|
@ -181,9 +177,9 @@ class PatternEngine {
|
|||
VectorRef *const values_expr) const;
|
||||
bool ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern,
|
||||
VectorRef *const values_expr) const;
|
||||
static bool AnfNodeEqual(const BaseRef &a, const BaseRef &b);
|
||||
static bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b);
|
||||
std::shared_ptr<Visitor> visitor_;
|
||||
std::function<bool(const BaseRef &, const BaseRef &)> eq_;
|
||||
std::function<bool(const BaseRef &, const BaseRef &)> type_eq_;
|
||||
};
|
||||
} // namespace mindspore
|
||||
namespace std {
|
||||
|
|
|
@ -46,19 +46,33 @@ std::shared_ptr<VectorRef> ExpandList(const std::vector<BaseRef> &list) {
|
|||
return new_list;
|
||||
}
|
||||
|
||||
bool DefaultVisitor::Visit(const VectorRef &v_any, BaseRef *const visit_out) const {
|
||||
static BaseRef GetVar(const BaseRef &x) {
|
||||
if (utils::isa<AnfNodePtr>(x)) {
|
||||
auto node = utils::cast<AnfNodePtr>(x);
|
||||
MS_LOG(DEBUG) << "TypeString [" + node->type_name() + "]";
|
||||
if (node->isa<VarNode>()) {
|
||||
MS_LOG(DEBUG) << "IsVarNode " + node->cast<VarNodePtr>()->var_->ToString();
|
||||
return node->cast<VarNodePtr>()->var_;
|
||||
}
|
||||
}
|
||||
return x;
|
||||
}
|
||||
|
||||
bool Visitor::Visit(const VectorRef &v_any, VectorRef *const values_ref, BaseRef *const visit_out) const {
|
||||
std::vector<BaseRef> out;
|
||||
(void)std::transform(v_any.begin(), v_any.end(), std::back_inserter(out),
|
||||
[this](const BaseRef &item) { return fn_(item); });
|
||||
for (const auto &element : v_any) {
|
||||
out.push_back(element);
|
||||
values_ref->push_back(GetVar(element));
|
||||
}
|
||||
if (visit_out != nullptr) {
|
||||
*visit_out = ExpandList(out);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool DefaultVisitor::Visit(const BaseRef &any, BaseRef *const visit_out) const {
|
||||
bool Visitor::Visit(const BaseRef &any, VectorRef *const values_ref, BaseRef *const visit_out) const {
|
||||
if (utils::isa<Seq>(any)) {
|
||||
return Visit(utils::cast<Seq>(any), visit_out);
|
||||
return Visit(utils::cast<Seq>(any), values_ref, visit_out);
|
||||
} else if (utils::isa<AnfNodePtr>(any)) {
|
||||
auto nodeptr = utils::cast<AnfNodePtr>(any);
|
||||
AnfNodePtr output;
|
||||
|
@ -66,7 +80,7 @@ bool DefaultVisitor::Visit(const BaseRef &any, BaseRef *const visit_out) const {
|
|||
if (visit_out == nullptr) {
|
||||
p_output = nullptr;
|
||||
}
|
||||
Visit(nodeptr, fn_, p_output);
|
||||
Visit(nodeptr, values_ref, p_output);
|
||||
if (visit_out != nullptr) {
|
||||
*visit_out = output;
|
||||
}
|
||||
|
@ -76,14 +90,14 @@ bool DefaultVisitor::Visit(const BaseRef &any, BaseRef *const visit_out) const {
|
|||
return false;
|
||||
}
|
||||
|
||||
void DefaultVisitor::Visit(const AnfNodePtr &node, const VisitFn &fn, AnfNodePtr *output) const {
|
||||
void Visitor::Visit(const AnfNodePtr &node, VectorRef *const values_ref, AnfNodePtr *output) const {
|
||||
if (node->isa<CNode>()) {
|
||||
Visit(node->cast<CNodePtr>(), fn, output);
|
||||
Visit(node->cast<CNodePtr>(), values_ref, output);
|
||||
return;
|
||||
}
|
||||
|
||||
if (node->isa<ValueNode>()) {
|
||||
Visit(node->cast<ValueNodePtr>(), fn, output);
|
||||
Visit(node->cast<ValueNodePtr>(), values_ref, output);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -92,17 +106,17 @@ void DefaultVisitor::Visit(const AnfNodePtr &node, const VisitFn &fn, AnfNodePtr
|
|||
}
|
||||
}
|
||||
|
||||
void DefaultVisitor::Visit(const CNodePtr &cnode, const VisitFn &fn, AnfNodePtr *output) const {
|
||||
void Visitor::Visit(const CNodePtr &cnode, VectorRef *const values_ref, AnfNodePtr *output) const {
|
||||
// if output is nullptr, it's not required to make the new CNode node.
|
||||
if (output == nullptr) {
|
||||
for (auto &inp : cnode->inputs()) {
|
||||
(void)fn(inp);
|
||||
auto var = GetVar(inp);
|
||||
values_ref->push_back(var);
|
||||
}
|
||||
|
||||
if (cnode->func_graph() != nullptr) {
|
||||
(void)fn(cnode->func_graph());
|
||||
values_ref->push_back(GetVar(cnode->func_graph()));
|
||||
} else {
|
||||
(void)fn(cnode->func_graph_as_var());
|
||||
values_ref->push_back(GetVar(cnode->func_graph_as_var()));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
@ -110,7 +124,10 @@ void DefaultVisitor::Visit(const CNodePtr &cnode, const VisitFn &fn, AnfNodePtr
|
|||
std::vector<AnfNodePtr> new_inputs;
|
||||
std::vector<BaseRef> after_cnode_fn;
|
||||
std::shared_ptr<VectorRef> out;
|
||||
(void)std::transform(cnode->inputs().begin(), cnode->inputs().end(), std::back_inserter(after_cnode_fn), fn);
|
||||
for (auto &input : cnode->inputs()) {
|
||||
after_cnode_fn.push_back(input);
|
||||
values_ref->push_back(GetVar(input));
|
||||
}
|
||||
if (CheckIfNeedExpand(after_cnode_fn)) {
|
||||
out = ExpandList(after_cnode_fn);
|
||||
}
|
||||
|
@ -130,13 +147,15 @@ void DefaultVisitor::Visit(const CNodePtr &cnode, const VisitFn &fn, AnfNodePtr
|
|||
BaseRef any_fg;
|
||||
AnfNodePtr new_cnode = nullptr;
|
||||
if (cnode->func_graph() != nullptr) {
|
||||
any_fg = fn(cnode->func_graph());
|
||||
any_fg = cnode->func_graph();
|
||||
values_ref->push_back(GetVar(any_fg));
|
||||
if (!utils::isa<FuncGraphPtr>(any_fg)) {
|
||||
MS_LOG(EXCEPTION) << "VisitError, fn not return the same type FuncGraphPtr";
|
||||
}
|
||||
new_cnode = std::make_shared<CNode>(new_inputs, utils::cast<FuncGraphPtr>(any_fg));
|
||||
} else {
|
||||
any_fg = fn(cnode->func_graph_as_var());
|
||||
any_fg = cnode->func_graph_as_var();
|
||||
values_ref->push_back(GetVar(any_fg));
|
||||
if (utils::isa<VarPtr>(any_fg)) {
|
||||
new_cnode = std::make_shared<CNode>(new_inputs, utils::cast<VarPtr>(any_fg));
|
||||
} else if (utils::isa<FuncGraphPtr>(any_fg)) {
|
||||
|
@ -149,8 +168,9 @@ void DefaultVisitor::Visit(const CNodePtr &cnode, const VisitFn &fn, AnfNodePtr
|
|||
*output = new_cnode;
|
||||
}
|
||||
|
||||
void DefaultVisitor::Visit(const ValueNodePtr &vnode, const VisitFn &fn, AnfNodePtr *output) const {
|
||||
const BaseRef &value = utils::cast<ValuePtr>(fn(vnode->value()));
|
||||
void Visitor::Visit(const ValueNodePtr &vnode, VectorRef *const values_ref, AnfNodePtr *output) const {
|
||||
values_ref->push_back(GetVar(vnode->value()));
|
||||
const BaseRef &value = utils::cast<ValuePtr>(vnode->value());
|
||||
if (utils::isa<ValuePtr>(value)) {
|
||||
if (output != nullptr) {
|
||||
auto ct = NewValueNode(utils::cast<ValuePtr>(value));
|
||||
|
|
|
@ -31,28 +31,15 @@
|
|||
|
||||
// namespace to support utils definition
|
||||
namespace mindspore {
|
||||
using VisitFn = std::function<BaseRef(const BaseRef &)>;
|
||||
|
||||
class Visitor {
|
||||
public:
|
||||
virtual void SetFn(VisitFn fn) = 0;
|
||||
virtual bool Visit(const BaseRef &e, BaseRef *out) const = 0;
|
||||
virtual bool Visit(const VectorRef &e, BaseRef *out) const = 0;
|
||||
virtual ~Visitor() = default;
|
||||
};
|
||||
|
||||
class DefaultVisitor : public Visitor {
|
||||
public:
|
||||
DefaultVisitor() : fn_(nullptr) {}
|
||||
~DefaultVisitor() override = default;
|
||||
void SetFn(VisitFn fn) override { fn_ = fn; };
|
||||
bool Visit(const VectorRef &e, BaseRef *out) const override;
|
||||
bool Visit(const BaseRef &e, BaseRef *out) const override;
|
||||
void Visit(const AnfNodePtr &node, const VisitFn &fn, AnfNodePtr *output) const;
|
||||
void Visit(const CNodePtr &cnode, const VisitFn &fn, AnfNodePtr *output) const;
|
||||
void Visit(const ValueNodePtr &vnode, const VisitFn &fn, AnfNodePtr *output) const;
|
||||
|
||||
VisitFn fn_;
|
||||
Visitor() {}
|
||||
~Visitor() = default;
|
||||
bool Visit(const VectorRef &e, VectorRef *const values_ref, BaseRef *out) const;
|
||||
bool Visit(const BaseRef &e, VectorRef *const values_ref, BaseRef *out) const;
|
||||
void Visit(const AnfNodePtr &node, VectorRef *const values_ref, AnfNodePtr *output) const;
|
||||
void Visit(const CNodePtr &cnode, VectorRef *const values_ref, AnfNodePtr *output) const;
|
||||
void Visit(const ValueNodePtr &vnode, VectorRef *const values_ref, AnfNodePtr *output) const;
|
||||
};
|
||||
|
||||
std::shared_ptr<VectorRef> ExpandList(const std::vector<BaseRef> &list);
|
||||
|
|
|
@ -199,9 +199,7 @@ EquivPtr TfliteLstmCellFusion::MatchGraph(const FuncGraphPtr &func_graph, const
|
|||
MS_ASSERT(func_graph != nullptr);
|
||||
MS_ASSERT(pattern != nullptr);
|
||||
auto return_node = func_graph->get_return();
|
||||
PatternEngine pattern_engine(PatternEngine(std::make_shared<DefaultVisitor>(),
|
||||
std::function<bool(const BaseRef &, const BaseRef &)>(AnfEqual),
|
||||
std::function<bool(const BaseRef &, const BaseRef &)>(CNodeTypeEqual)));
|
||||
PatternEngine pattern_engine(std::make_shared<Visitor>());
|
||||
auto empty_equiv = std::make_shared<Equiv>();
|
||||
EquivPtr equiv = pattern_engine.Match(pattern, return_node, *primitive_vars, empty_equiv);
|
||||
return equiv;
|
||||
|
|
|
@ -33,7 +33,7 @@ bool Equal(const BaseRef &a, const BaseRef &b) { return a == b; }
|
|||
class TestMatchEngine : public UT::Common {
|
||||
public:
|
||||
TestMatchEngine()
|
||||
: TU(std::make_shared<DefaultVisitor>(), std::function<bool(const BaseRef &, const BaseRef &)>(Equal)) {
|
||||
: TU(std::make_shared<Visitor>()) {
|
||||
equiv_null = std::make_shared<Equiv>();
|
||||
};
|
||||
|
||||
|
@ -215,17 +215,4 @@ TEST_F(TestMatchEngine, Match_CondVar) {
|
|||
equiv_null);
|
||||
ASSERT_EQ(d, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(TestMatchEngine, Match_Reify) {
|
||||
VarPtr v1 = std::make_shared<Var>();
|
||||
VarPtr sv = std::make_shared<SeqVar>();
|
||||
|
||||
BaseRef t;
|
||||
|
||||
equiv_null->clear();
|
||||
(*equiv_null)[sv] = BaseRef(std::make_shared<Seq>(PatternListType{3, 4}));
|
||||
t = TU.Replace(VectorRef({1, 2, sv}), equiv_null);
|
||||
ASSERT_EQ(t, BaseRef(VectorRef({1, 2, 3, 4})));
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue