forked from mindspore-Ecosystem/mindspore
Added a Pattern Matcher class to help with future optimization implementations. Includes changes to barnch_culling to show how to use the new Pattern Matcher infrastructure.
This commit is contained in:
parent
5cba231ba9
commit
8eaea74407
|
@ -0,0 +1,306 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_
|
||||
#define MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_
|
||||
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "ir/anf.h"
|
||||
#include "operator/ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
||||
///
|
||||
/// Base class for all recognizable patterns.
|
||||
/// We implement an Expression Template approach using static polymorphism based on
|
||||
/// the Curiously Recurring Template Pattern (CRTP) which "achieves a similar effect
|
||||
/// to the use of virtual functions without the costs..." as described in:
|
||||
/// https://en.wikipedia.org/wiki/Expression_templates and
|
||||
/// https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern
|
||||
/// The TryCapture function tries to capture the pattern with the given node.
|
||||
/// The GetNode function builds a new node using the captured values.
|
||||
///
|
||||
|
||||
template <typename T>
|
||||
class PBase {
|
||||
public:
|
||||
const T &get_object() const { return *static_cast<const T *>(this); }
|
||||
|
||||
template <typename TN>
|
||||
bool TryCapture(const TN &value) const {
|
||||
get_object().Reset();
|
||||
return get_object().TryCapture_(value);
|
||||
}
|
||||
|
||||
using Internal = T;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class PIsEqual {
|
||||
public:
|
||||
bool operator()(const T &lhs, const T &rhs) const { return lhs == rhs; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class PatternNode : public PBase<PatternNode<T> > {
|
||||
public:
|
||||
T GetNode(const AnfNodePtr &node) const {
|
||||
if (!captured_) {
|
||||
MS_EXCEPTION(ValueError) << "A Pattern wasn't captured for this Token before the call to GetNode.";
|
||||
}
|
||||
return captured_node_;
|
||||
}
|
||||
|
||||
bool TryCapture_(const T &node) const {
|
||||
if (!captured_) {
|
||||
captured_node_ = node;
|
||||
captured_ = true;
|
||||
return true;
|
||||
}
|
||||
return PIsEqual<T>()(captured_node_, node);
|
||||
}
|
||||
|
||||
void Reset() const { captured_ = false; }
|
||||
using Internal = const PatternNode<T> &;
|
||||
|
||||
protected:
|
||||
mutable T captured_node_;
|
||||
mutable bool captured_{false};
|
||||
};
|
||||
|
||||
template <typename T, typename T2>
|
||||
class PBinOperation : public PBase<PBinOperation<T, T2> > {
|
||||
public:
|
||||
PBinOperation(const PrimitivePtr &prim, const T &x, const T2 &y) : prim_(prim), x_(x), y_(y) {}
|
||||
|
||||
AnfNodePtr GetNode(const AnfNodePtr &node) const {
|
||||
AnfNodePtr lhs = x_.GetNode(node->func_graph());
|
||||
AnfNodePtr rhs = y_.GetNode(node->func_graph());
|
||||
AnfNodePtrList list = {prim_->cast<AnfNodePtr>(), lhs, rhs};
|
||||
return NewCNode(list, node->func_graph());
|
||||
}
|
||||
|
||||
bool TryCapture_(const AnfNodePtr &node) const {
|
||||
if (IsPrimitiveCNode(node, prim_)) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto inputs = cnode->inputs();
|
||||
if (inputs.size() == 3) {
|
||||
// Binary Prim assumes only two inputs
|
||||
if (!x_.TryCapture_(inputs[1]) || !y_.TryCapture_(inputs[2])) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void Reset() const {
|
||||
x_.Reset();
|
||||
y_.Reset();
|
||||
}
|
||||
|
||||
private:
|
||||
const PrimitivePtr prim_;
|
||||
typename T::Internal x_;
|
||||
typename T2::Internal y_;
|
||||
};
|
||||
|
||||
///
|
||||
/// Helper functions to apply a pattern function on all elements of a tuple
|
||||
///
|
||||
namespace tuple_utils {
|
||||
template <bool stop, size_t Index, typename Func>
|
||||
struct apply_func_tuple_item {
|
||||
template <typename TTuple>
|
||||
static void apply(Func *func, const TTuple &tuple) {
|
||||
(*func)(Index, std::get<Index>(tuple));
|
||||
apply_func_tuple_item<(Index + 1) == std::tuple_size<TTuple>::value, (Index + 1), Func>::apply(func, tuple);
|
||||
}
|
||||
};
|
||||
|
||||
template <size_t Index, typename Func>
|
||||
struct apply_func_tuple_item<true, Index, Func> {
|
||||
template <typename TTuple>
|
||||
static void apply(Func *func, const TTuple &tuple) {}
|
||||
};
|
||||
|
||||
template <typename Func, typename TTuple>
|
||||
inline void apply_func_tuple(Func *func, const TTuple &tuple) {
|
||||
apply_func_tuple_item<std::tuple_size<TTuple>::value == 0, 0, Func>::apply(func, tuple);
|
||||
}
|
||||
|
||||
struct PTupleResetCapture {
|
||||
template <typename T>
|
||||
void operator()(size_t i, const T &pattern) const {
|
||||
pattern.Reset();
|
||||
}
|
||||
};
|
||||
|
||||
struct PTupleCapture {
|
||||
explicit PTupleCapture(const AnfNodePtrList tuple) : tuple_(tuple) {}
|
||||
|
||||
template <typename TPattern>
|
||||
void operator()(size_t i, const TPattern &pattern) {
|
||||
// Check if the first node is a Primitive
|
||||
if (i == 0 && tuple_[i]->isa<Primitive>()) {
|
||||
auto prim = tuple_[i]->cast<PrimitivePtr>();
|
||||
if (tuple_[i] != pattern.GetNode(tuple_[i])) {
|
||||
captured_ = false;
|
||||
}
|
||||
} else {
|
||||
captured_ = captured_ && pattern.TryCapture_(tuple_[i]);
|
||||
}
|
||||
}
|
||||
|
||||
const AnfNodePtrList tuple_;
|
||||
bool captured_{true};
|
||||
};
|
||||
|
||||
struct PTupleGetNode {
|
||||
explicit PTupleGetNode(const AnfNodePtr &node) : node_(node) {}
|
||||
|
||||
template <typename TPattern>
|
||||
void operator()(size_t, const TPattern &pattern) {
|
||||
args_.push_back(pattern.GetNode(node_));
|
||||
}
|
||||
|
||||
const AnfNodePtr &node_;
|
||||
std::vector<AnfNodePtr> args_;
|
||||
};
|
||||
} // namespace tuple_utils
|
||||
|
||||
template <typename... TArgs>
|
||||
class PCNode : public PBase<PCNode<TArgs...> > {
|
||||
public:
|
||||
explicit PCNode(const TArgs &... args) : args_(args...) {}
|
||||
|
||||
AnfNodePtr GetNode(const AnfNodePtr &node) const {
|
||||
tuple_utils::PTupleGetNode get_node(node);
|
||||
tuple_utils::apply_func_tuple(&get_node, args_);
|
||||
return NewCNode(get_node.args_, node->func_graph());
|
||||
}
|
||||
|
||||
bool TryCapture_(const AnfNodePtr &node) const {
|
||||
if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto inputs = cnode->inputs();
|
||||
if (inputs.size() != sizeof...(TArgs)) {
|
||||
return false;
|
||||
}
|
||||
tuple_utils::PTupleCapture capture_func(inputs);
|
||||
tuple_utils::apply_func_tuple(&capture_func, args_);
|
||||
return capture_func.captured_;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void Reset() const {
|
||||
tuple_utils::PTupleResetCapture reset;
|
||||
tuple_utils::apply_func_tuple(&reset, args_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::tuple<typename TArgs::Internal...> args_;
|
||||
};
|
||||
|
||||
template <typename... TArgs>
|
||||
class PPrimitive : public PBase<PPrimitive<TArgs...> > {
|
||||
public:
|
||||
explicit PPrimitive(const PrimitivePtr &prim, const TArgs &... args) : prim_(prim), args_(args...) {}
|
||||
|
||||
AnfNodePtr GetNode(const AnfNodePtr &node) const {
|
||||
tuple_utils::PTupleGetNode get_node(node);
|
||||
tuple_utils::apply_func_tuple(&get_node, args_);
|
||||
auto prim_cnode = get_node.args_;
|
||||
prim_cnode.insert(prim_cnode.begin(), NewValueNode(prim_));
|
||||
return NewCNode(prim_cnode, node->func_graph());
|
||||
}
|
||||
|
||||
bool TryCapture_(const AnfNodePtr &node) const {
|
||||
if (IsPrimitiveCNode(node, prim_)) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto inputs = cnode->inputs();
|
||||
if ((inputs.size() - 1) != sizeof...(TArgs)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
AnfNodePtrList rest(inputs.begin() + 1, inputs.end());
|
||||
tuple_utils::PTupleCapture capture_func(rest);
|
||||
tuple_utils::apply_func_tuple(&capture_func, args_);
|
||||
|
||||
return capture_func.captured_;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void Reset() const {
|
||||
tuple_utils::PTupleResetCapture reset;
|
||||
tuple_utils::apply_func_tuple(&reset, args_);
|
||||
}
|
||||
|
||||
private:
|
||||
const PrimitivePtr prim_;
|
||||
std::tuple<typename TArgs::Internal...> args_;
|
||||
};
|
||||
|
||||
// Macro for binary operation functions
|
||||
#define BIN_OPERATION_PATTERN(Operator, MSPrimitive) \
|
||||
template <typename T, typename T2> \
|
||||
inline PBinOperation<T, T2> Operator(const PBase<T> &x, const PBase<T2> &y) { \
|
||||
return PBinOperation(MSPrimitive, x.get_object(), y.get_object()); \
|
||||
}
|
||||
|
||||
// Arithmetic operations
|
||||
BIN_OPERATION_PATTERN(operator+, prim::kPrimTensorAdd);
|
||||
BIN_OPERATION_PATTERN(operator*, prim::kPrimMul);
|
||||
|
||||
// Macros for match and replace
|
||||
#define MATCH_REPLACE(OrigNode, CaptureNode, ReplaceWith) \
|
||||
if ((CaptureNode).TryCapture(OrigNode)) { \
|
||||
return (ReplaceWith).GetNode(OrigNode); \
|
||||
}
|
||||
|
||||
#define MATCH_REPLACE_IF(OrigNode, CaptureNode, ReplaceWith, Condition) \
|
||||
if ((CaptureNode).TryCapture(OrigNode) && (Condition)) { \
|
||||
return (ReplaceWith).GetNode(OrigNode); \
|
||||
}
|
||||
|
||||
#define MATCH_REPLACE_IF_ELSE(OrigNode, CaptureNode, ReplaceWith, Condition, ElseNode) \
|
||||
if ((CaptureNode).TryCapture(OrigNode)) { \
|
||||
if ((Condition)) { \
|
||||
return (ReplaceWith).GetNode(OrigNode); \
|
||||
} \
|
||||
return (ElseNode).GetNode(OrigNode); \
|
||||
}
|
||||
|
||||
#define MATCH_REPLACE_LAMBDA(OrigNode, CaptureNode, Lambda) \
|
||||
if ((CaptureNode).TryCapture(OrigNode)) { \
|
||||
return (Lambda)(); \
|
||||
}
|
||||
|
||||
#define MATCH_REPLACE_LAMBDA_IF(OrigNode, CaptureNode, Lambda, Condition) \
|
||||
if ((CaptureNode).TryCapture(OrigNode) && (Condition)) { \
|
||||
return (Lambda)(); \
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // #ifndef MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_
|
|
@ -26,141 +26,61 @@
|
|||
#include "ir/func_graph.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "operator/ops.h"
|
||||
#include "ir/pattern_matcher.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
// {prim::kPrimSwitch, true, X, Y}
|
||||
// {prim::kPrimSwitch, false, X, Y}
|
||||
class SwitchSimplify : public AnfVisitor {
|
||||
class SwitchSimplify {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
Reset();
|
||||
auto getx = [this](const AnfNodePtr &node) -> bool {
|
||||
this->x_ = node;
|
||||
return true;
|
||||
};
|
||||
auto gety = [this](const AnfNodePtr &node) -> bool {
|
||||
this->y_ = node;
|
||||
return true;
|
||||
};
|
||||
AnfVisitor::Match(prim::kPrimSwitch, {IsValueNode<BoolImm>, getx, gety})(node);
|
||||
|
||||
// simplify the switch
|
||||
if (is_match_) {
|
||||
if (cond_) {
|
||||
return x_;
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||
PatternNode<AnfNodePtr> cond, true_br, false_br;
|
||||
auto SwitchSimplLambda = [&node, &cond, &true_br, &false_br]() -> AnfNodePtr {
|
||||
auto cond_value_ = GetValue<bool>(GetValueNode(cond.GetNode(node)));
|
||||
if (cond_value_) {
|
||||
return true_br.GetNode(node);
|
||||
}
|
||||
return y_;
|
||||
}
|
||||
return false_br.GetNode(node);
|
||||
};
|
||||
|
||||
MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), SwitchSimplLambda,
|
||||
IsValueNode<BoolImm>(cond.GetNode(node)));
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void Visit(const AnfNodePtr &node) override {
|
||||
if (!is_match_ && IsValueNode<BoolImm>(node)) {
|
||||
cond_ = GetValue<bool>(GetValueNode(node));
|
||||
is_match_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
void Reset() {
|
||||
x_ = nullptr;
|
||||
y_ = nullptr;
|
||||
cond_ = false;
|
||||
is_match_ = false;
|
||||
}
|
||||
|
||||
private:
|
||||
bool is_match_{false}, cond_{false};
|
||||
AnfNodePtr x_{nullptr}, y_{nullptr};
|
||||
};
|
||||
|
||||
// {prim::kPrimTupleGetItem, {prim::kPrimSwith, X0, X1, X2}, C} =>
|
||||
// {prim::kPrimSwith, X0, {prim::kPrimTupleGetItem, X1, C}, {prim::kPrimTupleGetItem, X2, C}}
|
||||
class FloatTupleGetItemSwitch : public AnfVisitor {
|
||||
class FloatTupleGetItemSwitch {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
Reset();
|
||||
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsVNode})(node);
|
||||
|
||||
auto fg = node->func_graph();
|
||||
if (Xs_.empty() || c_ == nullptr || fg == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto true_node = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), Xs_[1], c_});
|
||||
auto false_node = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), Xs_[2], c_});
|
||||
|
||||
return fg->NewCNode({NewValueNode(prim::kPrimSwitch), Xs_[0], true_node, false_node});
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||
PatternNode<AnfNodePtr> cond, true_br, false_br, x;
|
||||
MATCH_REPLACE_IF(node,
|
||||
PPrimitive(prim::kPrimTupleGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x),
|
||||
PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimTupleGetItem, true_br, x),
|
||||
PPrimitive(prim::kPrimTupleGetItem, false_br, x)),
|
||||
IsVNode(x.GetNode(node)));
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void Visit(const CNodePtr &cnode) override {
|
||||
// {prim::kPrimSwith, X1, X2, X3}
|
||||
if (!IsPrimitiveCNode(cnode, prim::kPrimSwitch) || cnode->size() != 4) {
|
||||
return;
|
||||
}
|
||||
|
||||
// copy X1, X2, X3
|
||||
auto &inputs = cnode->inputs();
|
||||
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_));
|
||||
}
|
||||
|
||||
void Visit(const ValueNodePtr &vnode) override { c_ = vnode; }
|
||||
|
||||
void Reset() {
|
||||
Xs_.clear();
|
||||
c_ = nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
AnfNodePtr c_{nullptr};
|
||||
std::vector<AnfNodePtr> Xs_{};
|
||||
};
|
||||
|
||||
// {prim::kPrimEnvGetItem, {prim::kPrimSwitch, X1, X2, X3}, X4, X5} =>
|
||||
// {prim::kPrimSwitch, X1, {prim::kPrimEnvGetItem, X2, X4, X5}, {prim::kPrimEnvGetItem, X3, X4, X5}}
|
||||
class FloatEnvGetItemSwitch : public AnfVisitor {
|
||||
class FloatEnvGetItemSwitch {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
is_match_ = false;
|
||||
AnfVisitor::Match(prim::kPrimEnvGetItem, {IsCNode, IsNode, IsNode})(node);
|
||||
if (!is_match_) {
|
||||
return nullptr;
|
||||
}
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||
PatternNode<AnfNodePtr> cond, true_br, false_br, x, x2;
|
||||
MATCH_REPLACE_IF(node,
|
||||
PPrimitive(prim::kPrimEnvGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x, x2),
|
||||
PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimEnvGetItem, true_br, x, x2),
|
||||
PPrimitive(prim::kPrimEnvGetItem, false_br, x, x2)),
|
||||
IsNode(x.GetNode(node)) && IsNode(x2.GetNode(node)));
|
||||
|
||||
// {prim::kPrimEnvGetItem, {...}, X4, X5}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto sw_node = cnode->input(1)->cast<CNodePtr>();
|
||||
auto x4 = cnode->input(2);
|
||||
auto x5 = cnode->input(3);
|
||||
|
||||
is_match_ = false;
|
||||
AnfVisitor::Match(prim::kPrimSwitch, {IsNode, IsNode, IsNode})(sw_node);
|
||||
if (!is_match_) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// {prim::kPrimSwitch, X1, X2, X3}
|
||||
auto x1 = sw_node->input(1);
|
||||
auto x2 = sw_node->input(2);
|
||||
auto x3 = sw_node->input(3);
|
||||
|
||||
auto fg = node->func_graph();
|
||||
if (fg == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto true_node = fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), x2, x4, x5});
|
||||
auto false_node = fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), x3, x4, x5});
|
||||
|
||||
return fg->NewCNode({NewValueNode(prim::kPrimSwitch), x1, true_node, false_node});
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void Visit(const AnfNodePtr &) override { is_match_ = true; }
|
||||
|
||||
private:
|
||||
bool is_match_{false};
|
||||
};
|
||||
|
||||
namespace internal {
|
||||
|
@ -173,79 +93,64 @@ AnfNodePtr TransformMergeBranches(const AnfNodePtr &true_output_node, const AnfN
|
|||
} // namespace internal
|
||||
|
||||
// {{prim::kPrimSwitch, X, G1, G2}, Xs}
|
||||
class ConvertSwitchReplacement : public AnfVisitor {
|
||||
class ConvertSwitchReplacement {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Reset();
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (cnode->size() < 1) {
|
||||
auto cnode_ = node->cast<CNodePtr>();
|
||||
if (cnode_->size() < 1) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// {prim::kPrimSwitch, X, G1, G2}
|
||||
AnfVisitor::Match(prim::kPrimSwitch, {IsNode, IsValueNode<FuncGraph>, IsValueNode<FuncGraph>})(cnode->input(0));
|
||||
if (g2_ == nullptr || g1_->output() == nullptr || g2_->output() == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
// for switch replace method, only graphs without graph inside can be replaced
|
||||
for (auto &item : g1_->value_nodes()) {
|
||||
auto value_node = item.first;
|
||||
if (IsValueNode<FuncGraph>(value_node)) {
|
||||
return nullptr;
|
||||
auto node_ = cnode_->input(0);
|
||||
|
||||
PatternNode<AnfNodePtr> cond, true_br, false_br;
|
||||
|
||||
auto ConvertSwitchLambda = [&node_, &cond, &true_br, &false_br]() -> AnfNodePtr {
|
||||
auto g1_ = GetValueNode<FuncGraphPtr>(true_br.GetNode(node_));
|
||||
auto g2_ = GetValueNode<FuncGraphPtr>(false_br.GetNode(node_));
|
||||
auto x_ = cond.GetNode(node_);
|
||||
|
||||
// for switch replace method, only graphs without graph inside can be replaced
|
||||
for (auto &item : g1_->value_nodes()) {
|
||||
auto value_node = item.first;
|
||||
if (IsValueNode<FuncGraph>(value_node)) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &item : g2_->value_nodes()) {
|
||||
auto value_node = item.first;
|
||||
if (IsValueNode<FuncGraph>(value_node)) {
|
||||
return nullptr;
|
||||
for (auto &item : g2_->value_nodes()) {
|
||||
auto value_node = item.first;
|
||||
if (IsValueNode<FuncGraph>(value_node)) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto true_output = g1_->output()->abstract();
|
||||
auto false_output = g2_->output()->abstract();
|
||||
auto trans_g1 = internal::TransformGraphCondTrueBranchNodes(g1_, x_);
|
||||
auto trans_g2 = internal::TransformGraphCondFalseBranchNodes(g2_, x_);
|
||||
auto true_output = g1_->output()->abstract();
|
||||
auto false_output = g2_->output()->abstract();
|
||||
auto trans_g1 = internal::TransformGraphCondTrueBranchNodes(g1_, x_);
|
||||
auto trans_g2 = internal::TransformGraphCondFalseBranchNodes(g2_, x_);
|
||||
|
||||
std::vector<AnfNodePtr> params;
|
||||
auto fg = node->func_graph();
|
||||
auto cloned_g1 = InlineClone(trans_g1, fg, params);
|
||||
auto cloned_g2 = InlineClone(trans_g2, fg, params);
|
||||
auto nnode = internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_, fg);
|
||||
return nnode;
|
||||
std::vector<AnfNodePtr> params;
|
||||
auto fg = node_->func_graph();
|
||||
auto cloned_g1 = InlineClone(trans_g1, fg, params);
|
||||
auto cloned_g2 = InlineClone(trans_g2, fg, params);
|
||||
auto nnode = internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_, fg);
|
||||
|
||||
return nnode;
|
||||
};
|
||||
|
||||
MATCH_REPLACE_LAMBDA_IF(node_, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), ConvertSwitchLambda,
|
||||
IsNode(cond.GetNode(node_)) && IsValueNode<FuncGraph>(true_br.GetNode(node_)) &&
|
||||
IsValueNode<FuncGraph>(false_br.GetNode(node_)));
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void Visit(const AnfNodePtr &node) override {
|
||||
if (x_ == nullptr) {
|
||||
x_ = node;
|
||||
return;
|
||||
}
|
||||
AnfVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void Visit(const ValueNodePtr &vnode) override {
|
||||
auto g = GetValueNode<FuncGraphPtr>(vnode);
|
||||
if (g1_ == nullptr) {
|
||||
g1_ = g;
|
||||
} else {
|
||||
g2_ = g;
|
||||
}
|
||||
}
|
||||
|
||||
void Reset() {
|
||||
x_ = nullptr;
|
||||
g1_ = nullptr;
|
||||
g2_ = nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
AnfNodePtr x_{nullptr};
|
||||
FuncGraphPtr g1_{nullptr}, g2_{nullptr};
|
||||
};
|
||||
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue