!17614 pass optimizer

From: @zhupuxu
Reviewed-by: @zhoufeng54
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-06-08 09:17:33 +08:00 committed by Gitee
commit 86cb781ce0
11 changed files with 205 additions and 122 deletions

View File

@ -32,9 +32,7 @@ namespace opt {
PatternProcessPass::PatternProcessPass(const std::string &name, bool multigraph)
: NodePass(name),
multigraph_(multigraph),
pattern_engine_(PatternEngine(std::make_shared<DefaultVisitor>(),
std::function<bool(const BaseRef &, const BaseRef &)>(AnfEqual),
std::function<bool(const BaseRef &, const BaseRef &)>(CNodeTypeEqual))),
pattern_engine_(PatternEngine(std::make_shared<Visitor>())),
primitive_vars_(std::make_shared<PrimitiveVarMap>()) {}
const BaseRef PatternProcessPass::DefinePattern() const {
@ -53,11 +51,14 @@ AnfNodePtr PatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNode
Build();
}
auto empty_equiv = std::make_shared<Equiv>();
MS_EXCEPTION_IF_NULL(primitive_vars_);
EquivPtr equiv = pattern_engine_.Match(pattern_, node, *primitive_vars_, empty_equiv);
if (equiv != nullptr && !equiv->empty()) {
return Process(func_graph, node, equiv);
auto primitive = GetCNodePrimitive(pattern_);
if (IsPrimitiveCNode(node, primitive)) {
auto empty_equiv = std::make_shared<Equiv>();
MS_EXCEPTION_IF_NULL(primitive_vars_);
EquivPtr equiv = pattern_engine_.Match(pattern_, node, *primitive_vars_, empty_equiv);
if (equiv != nullptr && !equiv->empty()) {
return Process(func_graph, node, equiv);
}
}
return nullptr;
}

View File

@ -55,9 +55,7 @@ class MultipleOutputPatternProcessPass : public PatternProcessPass {
public:
explicit MultipleOutputPatternProcessPass(const std::string &name = "", bool multigraph = true)
: PatternProcessPass(name, multigraph),
child_pattern_engine_(PatternEngine(std::make_shared<DefaultVisitor>(),
std::function<bool(const BaseRef &, const BaseRef &)>(AnfEqual),
std::function<bool(const BaseRef &, const BaseRef &)>(CNodeTypeEqual))),
child_pattern_engine_(PatternEngine(std::make_shared<Visitor>())),
child_primitive_vars_(std::make_shared<PrimitiveVarMap>()) {}
~MultipleOutputPatternProcessPass() override = default;
virtual BaseRef DefineAnotherPattern() const = 0;

View File

@ -24,6 +24,9 @@
namespace mindspore {
namespace opt {
class CacheManager;
using CacheManagerPtr = std::shared_ptr<CacheManager>;
// @brief ANF Graph level optimization base pass
class Pass {
public:
@ -31,9 +34,12 @@ class Pass {
virtual ~Pass() = default;
virtual bool Run(const FuncGraphPtr &func_graph) = 0;
virtual std::string name() const { return name_; }
void SetCacheManager(const CacheManagerPtr &cm) { cache_manager_ = cm; }
const CacheManagerPtr &GetCacheManager() const { return cache_manager_; }
private:
const std::string name_;
CacheManagerPtr cache_manager_;
};
using PassPtr = std::shared_ptr<Pass>;
} // namespace opt

View File

@ -23,9 +23,73 @@
#include "ir/manager.h"
#include "utils/ms_context.h"
#include "debug/anf_ir_dump.h"
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore {
namespace opt {
void CacheManager::Update(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto type_iter = type_map_.find(node);
auto shape_iter = shape_map_.find(node);
if (type_iter != type_map_.end()) {
type_map_.erase(type_iter);
}
if (shape_iter != shape_map_.end()) {
shape_map_.erase(shape_iter);
}
}
TypeId CacheManager::GetOutputType(const AnfNodePtr &node, size_t index) {
MS_EXCEPTION_IF_NULL(node);
auto iter = type_map_.find(node);
if (iter != type_map_.end()) {
auto types = iter->second;
auto type_iter = types.find(index);
if (type_iter != types.end()) {
return type_iter->second;
}
return kTypeUnknown;
}
auto output_nums = AnfAlgo::GetOutputTensorNum(node);
std::map<size_t, TypeId> index_to_types;
TypeId result = kTypeUnknown;
for (size_t i = 0; i < output_nums; i++) {
auto output_type = AnfAlgo::GetOutputInferDataType(node, i);
index_to_types.emplace(i, output_type);
if (index == i) {
result = output_type;
}
}
type_map_.emplace(node, index_to_types);
return result;
}
std::vector<size_t> CacheManager::GetOutputShape(const AnfNodePtr &node, size_t index) {
MS_EXCEPTION_IF_NULL(node);
auto iter = shape_map_.find(node);
if (iter != shape_map_.end()) {
auto shapes = iter->second;
auto shape_iter = shapes.find(index);
if (shape_iter != shapes.end()) {
return shape_iter->second;
}
return {};
}
auto output_nums = AnfAlgo::GetOutputTensorNum(node);
std::map<size_t, std::vector<size_t>> index_to_shapes;
std::vector<size_t> result = {};
for (size_t i = 0; i < output_nums; i++) {
auto output_shape = AnfAlgo::GetOutputInferShape(node, i);
index_to_shapes.emplace(i, output_shape);
if (index == i) {
result = output_shape;
}
}
shape_map_.emplace(node, index_to_shapes);
return result;
}
const std::vector<PassPtr> &PassManager::Passes() const { return passes_; }
void PassManager::AddPass(const PassPtr &pass) {
@ -84,6 +148,7 @@ bool PassManager::Run(const FuncGraphPtr &func_graph, const std::vector<PassPtr>
size_t num = 0;
for (const auto &pass : passes) {
if (pass != nullptr) {
pass->SetCacheManager(cache_manager_);
changed = RunPass(func_graph, num, pass) || changed;
DumpPassIR(func_graph, GetPassFullname(num, pass));
num++;

View File

@ -20,17 +20,32 @@
#include <vector>
#include <string>
#include <memory>
#include <map>
#include "backend/optimizer/common/pass.h"
#include "backend/optimizer/common/node_pass.h"
namespace mindspore {
namespace opt {
class CacheManager {
public:
CacheManager() {}
~CacheManager() = default;
void Update(const AnfNodePtr &node);
TypeId GetOutputType(const AnfNodePtr &node, size_t index);
std::vector<size_t> GetOutputShape(const AnfNodePtr &node, size_t index);
private:
std::map<AnfNodePtr, std::map<size_t, TypeId>> type_map_;
std::map<AnfNodePtr, std::map<size_t, std::vector<size_t>>> shape_map_;
};
using CacheManagerPtr = std::shared_ptr<CacheManager>;
// @brief For optimization passes management
class PassManager {
public:
explicit PassManager(const std::string &name = "pm", bool run_only_once = true)
: name_(name), passes_{}, run_only_once_(run_only_once) {}
: name_(name), passes_{}, run_only_once_(run_only_once), cache_manager_(std::make_shared<CacheManager>()) {}
virtual ~PassManager() = default;
// Get all the passes added by AddPass
const std::vector<PassPtr> &Passes() const;
@ -57,6 +72,7 @@ class PassManager {
const std::string name_;
std::vector<PassPtr> passes_;
bool run_only_once_;
CacheManagerPtr cache_manager_;
};
using PassManagerPtr = std::shared_ptr<PassManager>;
} // namespace opt

View File

@ -25,6 +25,7 @@
#include "ir/anf.h"
#include "utils/convert_utils_base.h"
#include "utils/overload.h"
#include "backend/optimizer/common/helper.h"
namespace mindspore {
static int GetNextTag() {
@ -156,33 +157,13 @@ bool PatternEngine::ToVector(const VectorRef &pattern_ref, const VectorRef &expr
bool PatternEngine::ToVector(const BaseRef &pattern_ref, const BaseRef &expr_ref, VectorRef *const values_pattern,
VectorRef *const values_expr) const {
MS_EXCEPTION_IF_NULL(values_expr);
// visitor to visite the list
auto appender_pattern = [](VectorRef &values) {
std::function<BaseRef(const BaseRef &)> fn = [&](const BaseRef &u) {
values.push_back(GetVar(u));
return u;
};
return fn;
};
visitor_->SetFn(appender_pattern(*values_pattern));
MS_LOG(DEBUG) << "visit pattern_ref";
bool success = visitor_->Visit(pattern_ref, nullptr);
bool success = visitor_->Visit(pattern_ref, values_pattern, nullptr);
if (!success) {
return false;
}
auto appender_expr = [](VectorRef &values) {
std::function<BaseRef(const BaseRef &)> fn = [&](const BaseRef &u) {
values.push_back(u);
return u;
};
return fn;
};
visitor_->SetFn(appender_expr(*values_expr));
MS_LOG(DEBUG) << "visit expr_ref";
return visitor_->Visit(expr_ref, nullptr);
return visitor_->Visit(expr_ref, values_expr, nullptr);
}
static int GetSVarStartIndex(const VectorRef &values) {
@ -292,7 +273,7 @@ EquivPtr PatternEngine::Match(const BaseRef &pattern, const BaseRef &expr, const
}
// 2. check equal
if (eq_(pattern_ref, expr_ref)) {
if (PatternEngine::AnfNodeEqual(pattern_ref, expr_ref)) {
return equiv;
}
@ -304,7 +285,7 @@ EquivPtr PatternEngine::Match(const BaseRef &pattern, const BaseRef &expr, const
// 4. here the type can be std:vector, std:list,
// or cnode.
if (!type_eq_(pattern_ref, expr_ref)) {
if (!PatternEngine::CNodeTypeEqual(pattern_ref, expr_ref)) {
MS_LOG(DEBUG) << "Type mismatch";
return nullptr;
}
@ -324,34 +305,62 @@ EquivPtr PatternEngine::Match(const BaseRef &pattern, const BaseRef &expr, const
return equiv;
}
BaseRef PatternEngine::Replace(const BaseRef &pattern, const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(equiv);
MS_LOG(DEBUG) << "-----[in Replace]";
BaseRef ref = GetVar(pattern);
BaseRef out;
bool is_match = false;
bool PatternEngine::AnfNodeEqual(const BaseRef &a, const BaseRef &b) {
if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
auto a_node = utils::cast<AnfNodePtr>(a);
auto b_node = utils::cast<AnfNodePtr>(b);
MS_EXCEPTION_IF_NULL(a_node);
MS_EXCEPTION_IF_NULL(b_node);
if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
auto a_value_node = a_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(a_value_node);
auto a_value = a_value_node->value();
MS_EXCEPTION_IF_NULL(a_value);
auto a_prim = a_value->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(a_prim);
// w is var
if (utils::isa<VarPtr>(ref)) {
const VarPtr &var = utils::cast<VarPtr>(ref);
auto iter = equiv->find(var);
if (iter != equiv->end()) {
out = iter->second;
is_match = true;
auto b_value_node = b_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(b_value_node);
auto b_value = b_value_node->value();
MS_EXCEPTION_IF_NULL(b_value);
auto b_prim = b_value->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(b_prim);
return a_prim->name() == b_prim->name();
} else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
if (a_value_node_ptr == nullptr) {
MS_LOG(EXCEPTION) << "cast value node ptr fail";
}
auto a_value_ptr = a_value_node_ptr->value();
if (a_value_ptr == nullptr) {
MS_LOG(EXCEPTION) << "value ptr is nullptr";
}
auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
if (b_value_node_ptr == nullptr) {
MS_LOG(EXCEPTION) << "cast value node ptr fail";
}
auto b_value_ptr = b_value_node_ptr->value();
if (b_value_ptr == nullptr) {
MS_LOG(EXCEPTION) << "value ptr is nullptr";
}
return (*a_value_ptr) == (*b_value_ptr);
}
MS_LOG(DEBUG) << "check AnfNodePtr equal";
}
if (is_match) {
return out;
if (utils::isa<FuncGraphPtr>(a) && utils::isa<FuncGraphPtr>(b)) {
MS_LOG(DEBUG) << "check GraphPtr equal";
}
return a == b;
}
// visitor to visit the list
std::function<BaseRef(BaseRef)> fn = [&, this, equiv](const BaseRef &u) { return Replace(u, equiv); };
visitor_->SetFn(fn);
BaseRef visit_out;
if (!visitor_->Visit(pattern, &visit_out)) {
return pattern;
bool PatternEngine::CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
// To matchCNode and Kernel's type
if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
return true;
}
return visit_out;
return a.type() == b.type();
}
} // namespace mindspore

View File

@ -163,16 +163,12 @@ inline bool DefaultTypeEq(const BaseRef &x, const BaseRef &y) { return x.type()
class PatternEngine {
public:
PatternEngine(const std::shared_ptr<Visitor> &visitor,
const std::function<bool(const BaseRef &, const BaseRef &)> &eq,
const std::function<bool(const BaseRef &, const BaseRef &)> &type_eq = DefaultTypeEq)
: visitor_(visitor), eq_(eq), type_eq_(type_eq) {}
explicit PatternEngine(const std::shared_ptr<Visitor> &visitor) : visitor_(visitor) {}
~PatternEngine() = default;
EquivPtr Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars,
EquivPtr equiv) const;
// Replace pattern with equivalent
BaseRef Replace(const BaseRef &pattern, const EquivPtr &equiv) const;
private:
EquivPtr AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr,
@ -181,9 +177,9 @@ class PatternEngine {
VectorRef *const values_expr) const;
bool ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern,
VectorRef *const values_expr) const;
static bool AnfNodeEqual(const BaseRef &a, const BaseRef &b);
static bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b);
std::shared_ptr<Visitor> visitor_;
std::function<bool(const BaseRef &, const BaseRef &)> eq_;
std::function<bool(const BaseRef &, const BaseRef &)> type_eq_;
};
} // namespace mindspore
namespace std {

View File

@ -46,19 +46,33 @@ std::shared_ptr<VectorRef> ExpandList(const std::vector<BaseRef> &list) {
return new_list;
}
bool DefaultVisitor::Visit(const VectorRef &v_any, BaseRef *const visit_out) const {
static BaseRef GetVar(const BaseRef &x) {
if (utils::isa<AnfNodePtr>(x)) {
auto node = utils::cast<AnfNodePtr>(x);
MS_LOG(DEBUG) << "TypeString [" + node->type_name() + "]";
if (node->isa<VarNode>()) {
MS_LOG(DEBUG) << "IsVarNode " + node->cast<VarNodePtr>()->var_->ToString();
return node->cast<VarNodePtr>()->var_;
}
}
return x;
}
bool Visitor::Visit(const VectorRef &v_any, VectorRef *const values_ref, BaseRef *const visit_out) const {
std::vector<BaseRef> out;
(void)std::transform(v_any.begin(), v_any.end(), std::back_inserter(out),
[this](const BaseRef &item) { return fn_(item); });
for (const auto &element : v_any) {
out.push_back(element);
values_ref->push_back(GetVar(element));
}
if (visit_out != nullptr) {
*visit_out = ExpandList(out);
}
return true;
}
bool DefaultVisitor::Visit(const BaseRef &any, BaseRef *const visit_out) const {
bool Visitor::Visit(const BaseRef &any, VectorRef *const values_ref, BaseRef *const visit_out) const {
if (utils::isa<Seq>(any)) {
return Visit(utils::cast<Seq>(any), visit_out);
return Visit(utils::cast<Seq>(any), values_ref, visit_out);
} else if (utils::isa<AnfNodePtr>(any)) {
auto nodeptr = utils::cast<AnfNodePtr>(any);
AnfNodePtr output;
@ -66,7 +80,7 @@ bool DefaultVisitor::Visit(const BaseRef &any, BaseRef *const visit_out) const {
if (visit_out == nullptr) {
p_output = nullptr;
}
Visit(nodeptr, fn_, p_output);
Visit(nodeptr, values_ref, p_output);
if (visit_out != nullptr) {
*visit_out = output;
}
@ -76,14 +90,14 @@ bool DefaultVisitor::Visit(const BaseRef &any, BaseRef *const visit_out) const {
return false;
}
void DefaultVisitor::Visit(const AnfNodePtr &node, const VisitFn &fn, AnfNodePtr *output) const {
void Visitor::Visit(const AnfNodePtr &node, VectorRef *const values_ref, AnfNodePtr *output) const {
if (node->isa<CNode>()) {
Visit(node->cast<CNodePtr>(), fn, output);
Visit(node->cast<CNodePtr>(), values_ref, output);
return;
}
if (node->isa<ValueNode>()) {
Visit(node->cast<ValueNodePtr>(), fn, output);
Visit(node->cast<ValueNodePtr>(), values_ref, output);
return;
}
@ -92,17 +106,17 @@ void DefaultVisitor::Visit(const AnfNodePtr &node, const VisitFn &fn, AnfNodePtr
}
}
void DefaultVisitor::Visit(const CNodePtr &cnode, const VisitFn &fn, AnfNodePtr *output) const {
void Visitor::Visit(const CNodePtr &cnode, VectorRef *const values_ref, AnfNodePtr *output) const {
// if output is nullptr, it's not required to make the new CNode node.
if (output == nullptr) {
for (auto &inp : cnode->inputs()) {
(void)fn(inp);
auto var = GetVar(inp);
values_ref->push_back(var);
}
if (cnode->func_graph() != nullptr) {
(void)fn(cnode->func_graph());
values_ref->push_back(GetVar(cnode->func_graph()));
} else {
(void)fn(cnode->func_graph_as_var());
values_ref->push_back(GetVar(cnode->func_graph_as_var()));
}
return;
}
@ -110,7 +124,10 @@ void DefaultVisitor::Visit(const CNodePtr &cnode, const VisitFn &fn, AnfNodePtr
std::vector<AnfNodePtr> new_inputs;
std::vector<BaseRef> after_cnode_fn;
std::shared_ptr<VectorRef> out;
(void)std::transform(cnode->inputs().begin(), cnode->inputs().end(), std::back_inserter(after_cnode_fn), fn);
for (auto &input : cnode->inputs()) {
after_cnode_fn.push_back(input);
values_ref->push_back(GetVar(input));
}
if (CheckIfNeedExpand(after_cnode_fn)) {
out = ExpandList(after_cnode_fn);
}
@ -130,13 +147,15 @@ void DefaultVisitor::Visit(const CNodePtr &cnode, const VisitFn &fn, AnfNodePtr
BaseRef any_fg;
AnfNodePtr new_cnode = nullptr;
if (cnode->func_graph() != nullptr) {
any_fg = fn(cnode->func_graph());
any_fg = cnode->func_graph();
values_ref->push_back(GetVar(any_fg));
if (!utils::isa<FuncGraphPtr>(any_fg)) {
MS_LOG(EXCEPTION) << "VisitError, fn not return the same type FuncGraphPtr";
}
new_cnode = std::make_shared<CNode>(new_inputs, utils::cast<FuncGraphPtr>(any_fg));
} else {
any_fg = fn(cnode->func_graph_as_var());
any_fg = cnode->func_graph_as_var();
values_ref->push_back(GetVar(any_fg));
if (utils::isa<VarPtr>(any_fg)) {
new_cnode = std::make_shared<CNode>(new_inputs, utils::cast<VarPtr>(any_fg));
} else if (utils::isa<FuncGraphPtr>(any_fg)) {
@ -149,8 +168,9 @@ void DefaultVisitor::Visit(const CNodePtr &cnode, const VisitFn &fn, AnfNodePtr
*output = new_cnode;
}
void DefaultVisitor::Visit(const ValueNodePtr &vnode, const VisitFn &fn, AnfNodePtr *output) const {
const BaseRef &value = utils::cast<ValuePtr>(fn(vnode->value()));
void Visitor::Visit(const ValueNodePtr &vnode, VectorRef *const values_ref, AnfNodePtr *output) const {
values_ref->push_back(GetVar(vnode->value()));
const BaseRef &value = utils::cast<ValuePtr>(vnode->value());
if (utils::isa<ValuePtr>(value)) {
if (output != nullptr) {
auto ct = NewValueNode(utils::cast<ValuePtr>(value));

View File

@ -31,28 +31,15 @@
// namespace to support utils definition
namespace mindspore {
using VisitFn = std::function<BaseRef(const BaseRef &)>;
class Visitor {
public:
virtual void SetFn(VisitFn fn) = 0;
virtual bool Visit(const BaseRef &e, BaseRef *out) const = 0;
virtual bool Visit(const VectorRef &e, BaseRef *out) const = 0;
virtual ~Visitor() = default;
};
class DefaultVisitor : public Visitor {
public:
DefaultVisitor() : fn_(nullptr) {}
~DefaultVisitor() override = default;
void SetFn(VisitFn fn) override { fn_ = fn; };
bool Visit(const VectorRef &e, BaseRef *out) const override;
bool Visit(const BaseRef &e, BaseRef *out) const override;
void Visit(const AnfNodePtr &node, const VisitFn &fn, AnfNodePtr *output) const;
void Visit(const CNodePtr &cnode, const VisitFn &fn, AnfNodePtr *output) const;
void Visit(const ValueNodePtr &vnode, const VisitFn &fn, AnfNodePtr *output) const;
VisitFn fn_;
Visitor() {}
~Visitor() = default;
bool Visit(const VectorRef &e, VectorRef *const values_ref, BaseRef *out) const;
bool Visit(const BaseRef &e, VectorRef *const values_ref, BaseRef *out) const;
void Visit(const AnfNodePtr &node, VectorRef *const values_ref, AnfNodePtr *output) const;
void Visit(const CNodePtr &cnode, VectorRef *const values_ref, AnfNodePtr *output) const;
void Visit(const ValueNodePtr &vnode, VectorRef *const values_ref, AnfNodePtr *output) const;
};
std::shared_ptr<VectorRef> ExpandList(const std::vector<BaseRef> &list);

View File

@ -199,9 +199,7 @@ EquivPtr TfliteLstmCellFusion::MatchGraph(const FuncGraphPtr &func_graph, const
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(pattern != nullptr);
auto return_node = func_graph->get_return();
PatternEngine pattern_engine(PatternEngine(std::make_shared<DefaultVisitor>(),
std::function<bool(const BaseRef &, const BaseRef &)>(AnfEqual),
std::function<bool(const BaseRef &, const BaseRef &)>(CNodeTypeEqual)));
PatternEngine pattern_engine(std::make_shared<Visitor>());
auto empty_equiv = std::make_shared<Equiv>();
EquivPtr equiv = pattern_engine.Match(pattern, return_node, *primitive_vars, empty_equiv);
return equiv;

View File

@ -33,7 +33,7 @@ bool Equal(const BaseRef &a, const BaseRef &b) { return a == b; }
class TestMatchEngine : public UT::Common {
public:
TestMatchEngine()
: TU(std::make_shared<DefaultVisitor>(), std::function<bool(const BaseRef &, const BaseRef &)>(Equal)) {
: TU(std::make_shared<Visitor>()) {
equiv_null = std::make_shared<Equiv>();
};
@ -215,17 +215,4 @@ TEST_F(TestMatchEngine, Match_CondVar) {
equiv_null);
ASSERT_EQ(d, nullptr);
}
TEST_F(TestMatchEngine, Match_Reify) {
VarPtr v1 = std::make_shared<Var>();
VarPtr sv = std::make_shared<SeqVar>();
BaseRef t;
equiv_null->clear();
(*equiv_null)[sv] = BaseRef(std::make_shared<Seq>(PatternListType{3, 4}));
t = TU.Replace(VectorRef({1, 2, sv}), equiv_null);
ASSERT_EQ(t, BaseRef(VectorRef({1, 2, 3, 4})));
}
} // namespace mindspore