Added OptimizerCaller class as a parent for optimization appliers

This commit is contained in:
Giancarlo Colmenares 2020-06-18 17:11:31 -04:00
parent ea87b6c443
commit cfc19a6274
4 changed files with 43 additions and 18 deletions

View File

@ -0,0 +1,29 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_
#define MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_
#include "ir/anf.h"
#include "optimizer/opt.h"
namespace mindspore {
class OptimizerCaller {
public:
virtual AnfNodePtr operator()(const opt::OptimizerPtr &, const AnfNodePtr &) { return nullptr; }
};
} // namespace mindspore
#endif // MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_

View File

@ -14,11 +14,10 @@
* limitations under the License.
*/
#include "ir/visitor.h"
#include "ir/func_graph.h"
#include "ir/visitor.h"
namespace mindspore {
AnfNodePtr AnfVisitor::operator()(const opt::OptimizerPtr &, const AnfNodePtr &) { return nullptr; }
void AnfVisitor::Visit(const AnfNodePtr &node) { node->accept(this); }
void AnfVisitor::Visit(const CNodePtr &cnode) {

View File

@ -18,14 +18,12 @@
#define MINDSPORE_CCSRC_IR_VISITOR_H_
#include <vector>
#include "ir/anf.h"
#include "optimizer/opt.h"
#include "ir/optimizer_caller.h"
namespace mindspore {
using VisitFuncType = std::function<void(const AnfNodePtr &)>;
class AnfVisitor {
class AnfVisitor : public OptimizerCaller {
public:
virtual AnfNodePtr operator()(const opt::OptimizerPtr &, const AnfNodePtr &);
virtual void Visit(const AnfNodePtr &);
virtual void Visit(const CNodePtr &);
virtual void Visit(const ValueNodePtr &);

View File

@ -20,22 +20,21 @@
#include <vector>
#include <algorithm>
#include "optimizer/optimizer.h"
#include "optimizer/irpass.h"
#include "ir/visitor.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "operator/ops.h"
#include "ir/optimizer_caller.h"
#include "ir/pattern_matcher.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
namespace mindspore {
namespace opt {
namespace irpass {
// {prim::kPrimSwitch, true, X, Y}
// {prim::kPrimSwitch, false, X, Y}
class SwitchSimplify {
class SwitchSimplify : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) {
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
PatternNode<AnfNodePtr> cond, true_br, false_br;
auto SwitchSimplLambda = [&node, &cond, &true_br, &false_br]() -> AnfNodePtr {
auto cond_value_ = GetValue<bool>(GetValueNode(cond.GetNode(node)));
@ -54,9 +53,9 @@ class SwitchSimplify {
// {prim::kPrimTupleGetItem, {prim::kPrimSwith, X0, X1, X2}, C} =>
// {prim::kPrimSwith, X0, {prim::kPrimTupleGetItem, X1, C}, {prim::kPrimTupleGetItem, X2, C}}
class FloatTupleGetItemSwitch {
class FloatTupleGetItemSwitch : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) {
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
PatternNode<AnfNodePtr> cond, true_br, false_br, x;
MATCH_REPLACE_IF(node,
PPrimitive(prim::kPrimTupleGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x),
@ -69,9 +68,9 @@ class FloatTupleGetItemSwitch {
// {prim::kPrimEnvGetItem, {prim::kPrimSwitch, X1, X2, X3}, X4, X5} =>
// {prim::kPrimSwitch, X1, {prim::kPrimEnvGetItem, X2, X4, X5}, {prim::kPrimEnvGetItem, X3, X4, X5}}
class FloatEnvGetItemSwitch {
class FloatEnvGetItemSwitch : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) {
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
PatternNode<AnfNodePtr> cond, true_br, false_br, x, x2;
MATCH_REPLACE_IF(node,
PPrimitive(prim::kPrimEnvGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x, x2),
@ -93,9 +92,9 @@ AnfNodePtr TransformMergeBranches(const AnfNodePtr &true_output_node, const AnfN
} // namespace internal
// {{prim::kPrimSwitch, X, G1, G2}, Xs}
class ConvertSwitchReplacement {
class ConvertSwitchReplacement : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) {
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
return nullptr;
}