forked from mindspore-Ecosystem/mindspore
Added OptimizerCaller class as a parent for optimization appliers
This commit is contained in:
parent
ea87b6c443
commit
cfc19a6274
|
@ -0,0 +1,29 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_
|
||||
#define MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_
|
||||
|
||||
#include "ir/anf.h"
|
||||
#include "optimizer/opt.h"
|
||||
|
||||
namespace mindspore {
|
||||
class OptimizerCaller {
|
||||
public:
|
||||
virtual AnfNodePtr operator()(const opt::OptimizerPtr &, const AnfNodePtr &) { return nullptr; }
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_
|
|
@ -14,11 +14,10 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ir/visitor.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/visitor.h"
|
||||
|
||||
namespace mindspore {
|
||||
AnfNodePtr AnfVisitor::operator()(const opt::OptimizerPtr &, const AnfNodePtr &) { return nullptr; }
|
||||
void AnfVisitor::Visit(const AnfNodePtr &node) { node->accept(this); }
|
||||
|
||||
void AnfVisitor::Visit(const CNodePtr &cnode) {
|
||||
|
|
|
@ -18,14 +18,12 @@
|
|||
#define MINDSPORE_CCSRC_IR_VISITOR_H_
|
||||
|
||||
#include <vector>
|
||||
#include "ir/anf.h"
|
||||
#include "optimizer/opt.h"
|
||||
#include "ir/optimizer_caller.h"
|
||||
|
||||
namespace mindspore {
|
||||
using VisitFuncType = std::function<void(const AnfNodePtr &)>;
|
||||
class AnfVisitor {
|
||||
class AnfVisitor : public OptimizerCaller {
|
||||
public:
|
||||
virtual AnfNodePtr operator()(const opt::OptimizerPtr &, const AnfNodePtr &);
|
||||
virtual void Visit(const AnfNodePtr &);
|
||||
virtual void Visit(const CNodePtr &);
|
||||
virtual void Visit(const ValueNodePtr &);
|
||||
|
|
|
@ -20,22 +20,21 @@
|
|||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
||||
#include "optimizer/optimizer.h"
|
||||
#include "optimizer/irpass.h"
|
||||
#include "ir/visitor.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "operator/ops.h"
|
||||
#include "ir/optimizer_caller.h"
|
||||
#include "ir/pattern_matcher.h"
|
||||
#include "operator/ops.h"
|
||||
#include "optimizer/irpass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
// {prim::kPrimSwitch, true, X, Y}
|
||||
// {prim::kPrimSwitch, false, X, Y}
|
||||
class SwitchSimplify {
|
||||
class SwitchSimplify : public OptimizerCaller {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
PatternNode<AnfNodePtr> cond, true_br, false_br;
|
||||
auto SwitchSimplLambda = [&node, &cond, &true_br, &false_br]() -> AnfNodePtr {
|
||||
auto cond_value_ = GetValue<bool>(GetValueNode(cond.GetNode(node)));
|
||||
|
@ -54,9 +53,9 @@ class SwitchSimplify {
|
|||
|
||||
// {prim::kPrimTupleGetItem, {prim::kPrimSwith, X0, X1, X2}, C} =>
|
||||
// {prim::kPrimSwith, X0, {prim::kPrimTupleGetItem, X1, C}, {prim::kPrimTupleGetItem, X2, C}}
|
||||
class FloatTupleGetItemSwitch {
|
||||
class FloatTupleGetItemSwitch : public OptimizerCaller {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
PatternNode<AnfNodePtr> cond, true_br, false_br, x;
|
||||
MATCH_REPLACE_IF(node,
|
||||
PPrimitive(prim::kPrimTupleGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x),
|
||||
|
@ -69,9 +68,9 @@ class FloatTupleGetItemSwitch {
|
|||
|
||||
// {prim::kPrimEnvGetItem, {prim::kPrimSwitch, X1, X2, X3}, X4, X5} =>
|
||||
// {prim::kPrimSwitch, X1, {prim::kPrimEnvGetItem, X2, X4, X5}, {prim::kPrimEnvGetItem, X3, X4, X5}}
|
||||
class FloatEnvGetItemSwitch {
|
||||
class FloatEnvGetItemSwitch : public OptimizerCaller {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
PatternNode<AnfNodePtr> cond, true_br, false_br, x, x2;
|
||||
MATCH_REPLACE_IF(node,
|
||||
PPrimitive(prim::kPrimEnvGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x, x2),
|
||||
|
@ -93,9 +92,9 @@ AnfNodePtr TransformMergeBranches(const AnfNodePtr &true_output_node, const AnfN
|
|||
} // namespace internal
|
||||
|
||||
// {{prim::kPrimSwitch, X, G1, G2}, Xs}
|
||||
class ConvertSwitchReplacement {
|
||||
class ConvertSwitchReplacement : public OptimizerCaller {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue