forked from mindspore-Ecosystem/mindspore
!3973 enable convert anf fusion pass and optimize
Merge pull request !3973 from zhengjun10/master
This commit is contained in:
commit
d6547a46e7
|
@ -13,15 +13,107 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/gllo/common/gllo_utils.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "src/gllo/common/utils.h"
|
||||
#include "src/ir/primitive_t_value.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
|
||||
using PrimitiveTValuePtr = std::shared_ptr<mindspore::lite::PrimitiveTValue>;
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr auto kAnfPrimitiveIndex = 0;
|
||||
bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<CNode>()) {
|
||||
return false;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
|
||||
}
|
||||
|
||||
bool IsRealKernel(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
// parameter and value node is not a real kernel too
|
||||
if (!node->isa<CNode>()) {
|
||||
return true;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->inputs().empty()) {
|
||||
MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << node->DebugString();
|
||||
}
|
||||
auto input = cnode->inputs()[0];
|
||||
bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) ||
|
||||
IsPrimitive(input, prim::kPrimTensorSummary) ||
|
||||
IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
|
||||
IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
|
||||
IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) ||
|
||||
IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial);
|
||||
return !is_virtual_node;
|
||||
}
|
||||
|
||||
ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
|
||||
if (utils::isa<int>(sexp)) {
|
||||
return NewValueNode(utils::cast<int>(sexp));
|
||||
}
|
||||
if (utils::isa<float>(sexp)) {
|
||||
return NewValueNode(utils::cast<float>(sexp));
|
||||
}
|
||||
if (utils::isa<bool>(sexp)) {
|
||||
return NewValueNode(utils::cast<bool>(sexp));
|
||||
}
|
||||
if (utils::isa<ValuePtr>(sexp)) {
|
||||
return NewValueNode(utils::cast<ValuePtr>(sexp));
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
|
||||
if (utils::isa<FuncGraphPtr>(graph)) {
|
||||
return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
|
||||
}
|
||||
if (utils::isa<VarPtr>(graph)) {
|
||||
return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
|
||||
if (utils::isa<VarPtr>(graph)) {
|
||||
MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
|
||||
return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
|
||||
}
|
||||
if (utils::isa<FuncGraphPtr>(graph)) {
|
||||
MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
|
||||
return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
|
||||
}
|
||||
MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
|
||||
bool multigraph) {
|
||||
MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
|
||||
std::vector<AnfNodePtr> input_nodes;
|
||||
const auto &tuple = utils::cast<VectorRef>(sexp);
|
||||
if (multigraph && utils::isa<VarPtr>(graph)) {
|
||||
for (auto &x : tuple) {
|
||||
AnfNodePtr node = SexpToNode(x, std::make_shared<Var>("G"), primitive_vars, true);
|
||||
input_nodes.push_back(node);
|
||||
}
|
||||
VarPtr var_ptr = utils::cast<VarPtr>(graph);
|
||||
return std::make_shared<CNode>(input_nodes, var_ptr);
|
||||
}
|
||||
|
||||
for (auto &x : tuple) {
|
||||
AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph);
|
||||
input_nodes.push_back(node);
|
||||
}
|
||||
return CreateCNodeWithGraph(input_nodes, graph);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool AnfEqual(const BaseRef &a, const BaseRef &b) {
|
||||
if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
|
||||
auto a_node = utils::cast<AnfNodePtr>(a);
|
||||
|
@ -64,15 +156,15 @@ bool AnfEqual(const BaseRef &a, const BaseRef &b) {
|
|||
}
|
||||
|
||||
if (utils::isa<lite::PrimitiveTValue>(a_value_ptr) && utils::isa<lite::PrimitiveTValue>(b_value_ptr)) {
|
||||
auto a_obj = (lite::PrimitiveTValue *)(a_value_ptr.get());
|
||||
auto b_obj = (lite::PrimitiveTValue *)(b_value_ptr.get());
|
||||
auto a_obj = (lite::PrimitiveTValue *) (a_value_ptr.get());
|
||||
auto b_obj = (lite::PrimitiveTValue *) (b_value_ptr.get());
|
||||
return (*a_obj) == (*b_obj);
|
||||
} else {
|
||||
return (*a_value_ptr) == (*b_value_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (a.m_ptr->isa<lite::PrimitiveTValue>()) {
|
||||
if (a.m_ptr->isa<lite::PrimitiveTValue>() && b.m_ptr->isa<lite::PrimitiveTValue>()) {
|
||||
auto a_value_node_ptr = a.m_ptr->cast<PrimitiveTValuePtr>();
|
||||
auto b_value_node_ptr = b.m_ptr->cast<PrimitiveTValuePtr>();
|
||||
return a_value_node_ptr->GetPrimitiveT()->value.type == b_value_node_ptr->GetPrimitiveT()->value.type;
|
||||
|
@ -89,70 +181,8 @@ bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
|
|||
return a.type() == b.type();
|
||||
}
|
||||
|
||||
namespace {
|
||||
ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
|
||||
if (utils::isa<int>(sexp)) {
|
||||
return NewValueNode(utils::cast<int>(sexp));
|
||||
}
|
||||
if (utils::isa<float>(sexp)) {
|
||||
return NewValueNode(utils::cast<float>(sexp));
|
||||
}
|
||||
if (utils::isa<bool>(sexp)) {
|
||||
return NewValueNode(utils::cast<bool>(sexp));
|
||||
}
|
||||
if (utils::isa<ValuePtr>(sexp)) {
|
||||
return NewValueNode(utils::cast<ValuePtr>(sexp));
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
|
||||
if (utils::isa<FuncGraphPtr>(graph)) {
|
||||
return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
|
||||
}
|
||||
if (utils::isa<VarPtr>(graph)) {
|
||||
return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
|
||||
if (utils::isa<VarPtr>(graph)) {
|
||||
// MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
|
||||
return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
|
||||
}
|
||||
if (utils::isa<FuncGraphPtr>(graph)) {
|
||||
// MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
|
||||
return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
|
||||
}
|
||||
MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
|
||||
bool multigraph) {
|
||||
// MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
|
||||
std::vector<AnfNodePtr> input_nodes;
|
||||
const auto &tuple = utils::cast<VectorRef>(sexp);
|
||||
if (multigraph && utils::isa<VarPtr>(graph)) {
|
||||
for (auto &x : tuple) {
|
||||
AnfNodePtr node = SexpToNode(x, std::make_shared<Var>("G"), primitive_vars, true);
|
||||
input_nodes.push_back(node);
|
||||
}
|
||||
VarPtr var_ptr = utils::cast<VarPtr>(graph);
|
||||
return std::make_shared<CNode>(input_nodes, var_ptr);
|
||||
}
|
||||
|
||||
for (auto &x : tuple) {
|
||||
AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph);
|
||||
input_nodes.push_back(node);
|
||||
}
|
||||
return CreateCNodeWithGraph(input_nodes, graph);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
|
||||
// MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
|
||||
MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
|
||||
MS_EXCEPTION_IF_NULL(primitive_vars);
|
||||
if (utils::isa<VectorRef>(sexp)) {
|
||||
return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
|
||||
|
@ -176,6 +206,38 @@ AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap
|
|||
return value_node;
|
||||
}
|
||||
|
||||
bool IsRealCNodeKernel(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
// parameter and value node is not a real cnode kernel
|
||||
if (!node->isa<CNode>()) {
|
||||
return false;
|
||||
}
|
||||
// return considered as a real node
|
||||
if (CheckPrimitiveType(node, prim::kPrimReturn)) {
|
||||
return true;
|
||||
}
|
||||
return IsRealKernel(node);
|
||||
}
|
||||
bool IsGraphKernel(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
// graph kernel should be a real cnode kernel.
|
||||
if (!IsRealCNodeKernel(node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto input = cnode->input(kAnfPrimitiveIndex);
|
||||
// graph kernel should has func_graph as first input.
|
||||
if (!IsValueNode<FuncGraph>(input)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto func_graph = GetValueNode<FuncGraphPtr>(input);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
return func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
|
||||
}
|
||||
|
||||
void CheckIfFuncGraphIsNull(const FuncGraphPtr &graph) {
|
||||
if (graph == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The graph is null.";
|
|
@ -14,22 +14,21 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_UTILS_H_
|
||||
#define MINDSPORE_LITE_SRC_PASS_COMMON_UTILS_H_
|
||||
#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_
|
||||
#define MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_
|
||||
|
||||
#include <mindspore/lite/src/ir/primitive_t_value.h>
|
||||
#include <memory>
|
||||
#include "src/ir/primitive_t_value.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "src/gllo/common/pattern_engine.h"
|
||||
#include "backend/optimizer/common/pattern_engine.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "src/param_value_lite.h"
|
||||
|
||||
using PrimitiveTValuePtr = std::shared_ptr<mindspore::lite::PrimitiveTValue>;
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
||||
bool AnfEqual(const BaseRef &a, const BaseRef &b);
|
||||
|
||||
bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b);
|
||||
|
@ -37,6 +36,10 @@ bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b);
|
|||
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
|
||||
bool multigraph = false);
|
||||
|
||||
bool IsRealCNodeKernel(const AnfNodePtr &node);
|
||||
|
||||
bool IsGraphKernel(const AnfNodePtr &node);
|
||||
|
||||
void CheckIfFuncGraphIsNull(const FuncGraphPtr &graph);
|
||||
|
||||
void CheckIfAnfNodeIsNull(const AnfNodePtr &node);
|
||||
|
@ -61,4 +64,4 @@ bool IsParamNode(const BaseRef &n);
|
|||
bool IsConvNode(const BaseRef &n);
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_PASS_COMMON_UTILS_H_
|
||||
#endif // MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_
|
|
@ -13,7 +13,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/gllo/common/node_pass.h"
|
||||
#include "backend/optimizer/common/node_pass.h"
|
||||
|
||||
#include <unordered_set>
|
||||
#include <deque>
|
||||
|
@ -22,6 +22,7 @@
|
|||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/manager.h"
|
||||
#include "src/gllo/common/gllo_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -54,6 +55,9 @@ bool NodePass::Run(const FuncGraphPtr &func_graph) {
|
|||
MS_EXCEPTION_IF_NULL(const_func_graph);
|
||||
todo.push_back(const_func_graph->output());
|
||||
} else if (new_node && new_node->isa<CNode>()) {
|
||||
if (IsGraphKernel(new_node)) {
|
||||
todo.push_back(new_node);
|
||||
}
|
||||
auto cnode = new_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto inputs = cnode->inputs();
|
||||
|
|
|
@ -1,36 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_NODE_PASS_H_
|
||||
#define MINDSPORE_LITE_SRC_PASS_COMMON_NODE_PASS_H_
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include "src/gllo/common/pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// @brief ANF Node level optimization base pass
|
||||
class NodePass : public Pass {
|
||||
public:
|
||||
explicit NodePass(const std::string &name) : Pass(name) {}
|
||||
~NodePass() override = default;
|
||||
bool Run(const FuncGraphPtr &func_graph) final;
|
||||
virtual AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) = 0;
|
||||
};
|
||||
using NodePassPtr = std::shared_ptr<NodePass>;
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_PASS_COMMON_NODE_PASS_H_
|
|
@ -23,8 +23,7 @@
|
|||
#include <utility>
|
||||
#include <initializer_list>
|
||||
|
||||
#include "src/gllo/common/pass_manager.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/optimizer/common/pass_manager.h"
|
||||
#include "ir/manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
|
|
@ -26,9 +26,9 @@
|
|||
#include "ir/graph_utils.h"
|
||||
#include "src/common/utils.h"
|
||||
|
||||
#include "src/gllo/common/pass_manager.h"
|
||||
#include "src/gllo/common/pattern_engine.h"
|
||||
#include "src/gllo/common/utils.h"
|
||||
#include "backend/optimizer/common/pass_manager.h"
|
||||
#include "backend/optimizer/common/pattern_engine.h"
|
||||
#include "src/gllo/common/gllo_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/gllo/common/pass_manager.h"
|
||||
#include "backend/optimizer/common/pass_manager.h"
|
||||
|
||||
#include <sys/time.h>
|
||||
#include <unordered_set>
|
||||
|
|
|
@ -1,61 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_PASS_MANAGER_H_
|
||||
#define MINDSPORE_LITE_SRC_PASS_COMMON_PASS_MANAGER_H_
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include "src/gllo/common/pass.h"
|
||||
#include "src/gllo/common/node_pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// @brief For optimization passes management
|
||||
class PassManager {
|
||||
public:
|
||||
explicit PassManager(const std::string &name = "pm", bool run_only_once = true)
|
||||
: name_(name), passes_{}, run_only_once_(run_only_once) {}
|
||||
virtual ~PassManager() = default;
|
||||
// Get all the passes added by AddPass
|
||||
const std::vector<PassPtr> &Passes() const;
|
||||
// Add graph pass, the pass object will be freed when pass manager freed.
|
||||
void AddPass(const PassPtr &pass);
|
||||
// Run passes added in pass manager on the input graph
|
||||
// @param [inout] graph The graph to be optimized
|
||||
// @return true, graph changed
|
||||
// @return false, graph not changed
|
||||
bool Run(const FuncGraphPtr &func_graph) const;
|
||||
// Run the given graph passes on the input graph
|
||||
// @param [inout] graph The graph to be optimized
|
||||
// @param [in] passes The given graph passes
|
||||
// @return true, graph changed
|
||||
// @return false, graph not changed
|
||||
bool Run(const FuncGraphPtr &func_graph, const std::vector<PassPtr> &passes) const;
|
||||
std::string name() const { return name_; }
|
||||
|
||||
private:
|
||||
const std::string name_;
|
||||
std::vector<PassPtr> passes_;
|
||||
bool run_only_once_;
|
||||
};
|
||||
using PassManagerPtr = std::shared_ptr<PassManager>;
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_PASS_COMMON_PASS_MANAGER_H_
|
|
@ -1,365 +0,0 @@
|
|||
/**
|
||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
*
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/gllo/common/pattern_engine.h"
|
||||
|
||||
#include <exception>
|
||||
#include <iostream>
|
||||
#include <functional>
|
||||
#include <iterator>
|
||||
|
||||
#include "ir/func_graph.h"
|
||||
#include "mindspore/core/ir/primitive.h"
|
||||
#include "utils/info.h"
|
||||
#include "ir/anf.h"
|
||||
#include "utils/convert_utils_base.h"
|
||||
#include "utils/overload.h"
|
||||
|
||||
|
||||
namespace mindspore {
|
||||
static int GetNextTag() {
|
||||
static int kID = 0;
|
||||
return kID++;
|
||||
}
|
||||
|
||||
void Var::EnsureTag() {
|
||||
if (tag_.length() == 0) {
|
||||
std::ostringstream buffer;
|
||||
buffer << "_" << GetNextTag();
|
||||
tag_ = buffer.str();
|
||||
}
|
||||
}
|
||||
|
||||
bool operator==(const VarPtr &lhs, const VarPtr &rhs) {
|
||||
if (lhs->isa<CondVar>() && rhs->isa<CondVar>()) {
|
||||
CondVarPtr v1 = dyn_cast<CondVar>(lhs);
|
||||
CondVarPtr v2 = dyn_cast<CondVar>(rhs);
|
||||
return *v1 == *v2;
|
||||
}
|
||||
|
||||
if (lhs->isa<SeqVar>() && rhs->isa<SeqVar>()) {
|
||||
SVarPtr v1 = dyn_cast<SeqVar>(lhs);
|
||||
SVarPtr v2 = dyn_cast<SeqVar>(rhs);
|
||||
return *v1 == *v2;
|
||||
}
|
||||
return (*lhs == *rhs);
|
||||
}
|
||||
|
||||
std::string SeqVar::ToString() const {
|
||||
std::ostringstream buffer;
|
||||
buffer << "SeqVar(" << tag() << ", " << subvar_->ToString() << ")";
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &os, const VarPtr &var) {
|
||||
if (var == nullptr) {
|
||||
os << "";
|
||||
} else {
|
||||
os << var->ToString();
|
||||
}
|
||||
return os;
|
||||
}
|
||||
|
||||
template <>
|
||||
|
||||
std::ostream &operator<<<VarPtr, BaseRef>(std::ostream &os, const Equiv &equiv) {
|
||||
os << "[Equiv]"
|
||||
<< "\n";
|
||||
for (auto &equiv_item : equiv) {
|
||||
auto k = equiv_item.first;
|
||||
os << k << ":";
|
||||
BaseRef x = equiv_item.second;
|
||||
if (utils::isa<AnfNodePtr>(x)) {
|
||||
auto node = utils::cast<AnfNodePtr>(x);
|
||||
os << "TypeString[" << node->type_name() << "]";
|
||||
if (IsValueNode<FuncGraph>(node)) {
|
||||
os << "IsValueNodeGraph ";
|
||||
}
|
||||
os << "type " << node->type_name();
|
||||
if (node->isa<ValueNode>()) {
|
||||
os << " value " << GetValueNode(node);
|
||||
}
|
||||
os << " addr: " << node;
|
||||
} else if (utils::isa<Named>(x)) {
|
||||
os << "Named " << x.ToString().c_str();
|
||||
} else if (utils::isa<VarPtr>(x)) {
|
||||
os << "TypeString[Var]";
|
||||
os << utils::cast<VarPtr>(x);
|
||||
} else if (utils::isa<FuncGraphPtr>(x)) {
|
||||
os << "TypeString[Graph]";
|
||||
}
|
||||
os << "\n";
|
||||
}
|
||||
return os;
|
||||
}
|
||||
|
||||
|
||||
static BaseRef GetVar(const BaseRef &x) {
|
||||
// MS_LOG(DEBUG) << "getVar start :%s" + x.ToString();
|
||||
if (utils::isa<AnfNodePtr>(x)) {
|
||||
auto node = utils::cast<AnfNodePtr>(x);
|
||||
// MS_LOG(DEBUG) << "TypeString [" + node->type_name() + "]";
|
||||
if (node->isa<VarNode>()) {
|
||||
// MS_LOG(DEBUG) << "IsVarNode " + node->cast<VarNodePtr>()->var_->ToString();
|
||||
return node->cast<VarNodePtr>()->var_;
|
||||
}
|
||||
// if (node->isa<ValueNode>()) {
|
||||
// MS_LOG(DEBUG) << "value " + GetValueNode(node)->ToString() + " addr: " + node->ToString();
|
||||
// } else {
|
||||
// MS_LOG(DEBUG) << "type " + node->type_name();
|
||||
// }
|
||||
// } else if (utils::isa<Named>(x)) {
|
||||
// MS_LOG(DEBUG) << "Named " + x.ToString();
|
||||
// } else if (utils::isa<VectorRef>(x)) {
|
||||
// MS_LOG(DEBUG) << "VectorRef";
|
||||
// } else if (utils::isa<VarPtr>(x)) {
|
||||
// MS_LOG(DEBUG) << "TypeString[Var] " + x.ToString();
|
||||
}
|
||||
// MS_LOG(DEBUG) << "GetVar end: " + x.ToString();
|
||||
return x;
|
||||
}
|
||||
|
||||
EquivPtr MatchOnVar(const BaseRef &pattern, const BaseRef &expr, EquivPtr equiv) {
|
||||
MS_LOG(DEBUG) << "MatchOnVar pattern " + pattern.ToString() + " expr: " + expr.ToString();
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
if (utils::isa<VarPtr>(pattern)) {
|
||||
VarPtr var = utils::cast<VarPtr>(pattern);
|
||||
if (var->matches(expr)) {
|
||||
(*equiv)[var] = expr;
|
||||
MS_LOG(DEBUG) << "pattern is var match: " + pattern.ToString() + ", " + expr.ToString();
|
||||
return equiv;
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool PatternEngine::ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern,
|
||||
VectorRef *const values_expr) const {
|
||||
MS_EXCEPTION_IF_NULL(values_expr);
|
||||
if (utils::isa<SeqPtr>(pattern_ref)) {
|
||||
*values_pattern = pattern_ref;
|
||||
*values_expr = expr_ref;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool PatternEngine::ToVector(const BaseRef &pattern_ref, const BaseRef &expr_ref, VectorRef *const values_pattern,
|
||||
VectorRef *const values_expr) const {
|
||||
MS_EXCEPTION_IF_NULL(values_expr);
|
||||
// visitor to visite the list
|
||||
auto appender_pattern = [](VectorRef &values) {
|
||||
std::function<BaseRef(const BaseRef &)> fn = [&](const BaseRef &u) {
|
||||
values.push_back(GetVar(u));
|
||||
return u;
|
||||
};
|
||||
return fn;
|
||||
};
|
||||
|
||||
visitor_->SetFn(appender_pattern(*values_pattern));
|
||||
// MS_LOG(DEBUG) << "visit pattern_ref";
|
||||
bool success = visitor_->Visit(pattern_ref, nullptr);
|
||||
if (!success) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto appender_expr = [](VectorRef &values) {
|
||||
std::function<BaseRef(const BaseRef &)> fn = [&](const BaseRef &u) {
|
||||
values.push_back(u);
|
||||
return u;
|
||||
};
|
||||
return fn;
|
||||
};
|
||||
|
||||
visitor_->SetFn(appender_expr(*values_expr));
|
||||
// MS_LOG(DEBUG) << "visit expr_ref";
|
||||
return visitor_->Visit(expr_ref, nullptr);
|
||||
}
|
||||
|
||||
static int GetSVarStartIndex(const VectorRef &values) {
|
||||
int index = -1;
|
||||
int count = 0;
|
||||
for (auto &value : values) {
|
||||
if (utils::isa<VarPtr>(value) && utils::cast<VarPtr>(value)->isa<SeqVar>()) {
|
||||
if (index != -1) {
|
||||
// MS_LOG(DEBUG) << "Multiple SVars in sequence";
|
||||
return kInvalidVarIndex;
|
||||
}
|
||||
index = count;
|
||||
}
|
||||
count++;
|
||||
}
|
||||
return index;
|
||||
}
|
||||
|
||||
void UpdateEquivMap(const VectorRef &values_pattern, const BaseRef &expr_ref, const PrimitiveVarMap &primitive_vars,
|
||||
EquivPtr equiv) {
|
||||
if (equiv == nullptr || values_pattern.empty() || !utils::isa<AnfNodePtr>(values_pattern[0]) ||
|
||||
!utils::isa<AnfNodePtr>(expr_ref)) {
|
||||
return;
|
||||
}
|
||||
auto real_node = utils::cast<AnfNodePtr>(expr_ref);
|
||||
MS_EXCEPTION_IF_NULL(real_node);
|
||||
if (!real_node->isa<CNode>()) {
|
||||
return;
|
||||
}
|
||||
auto prim_node = utils::cast<AnfNodePtr>(values_pattern[0]);
|
||||
MS_EXCEPTION_IF_NULL(prim_node);
|
||||
if (!IsValueNode<Primitive>(prim_node)) {
|
||||
return;
|
||||
}
|
||||
ValuePtr value = GetValueNode(prim_node);
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
auto prim = value->cast<PrimitivePtr>();
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto iter = primitive_vars.find(prim);
|
||||
if (iter == primitive_vars.end()) {
|
||||
return;
|
||||
}
|
||||
(*equiv)[iter->second] = real_node;
|
||||
}
|
||||
|
||||
EquivPtr PatternEngine::AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr,
|
||||
const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const {
|
||||
int svar_index = GetSVarStartIndex(values_pattern);
|
||||
if (svar_index == kInvalidVarIndex) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
size_t values_pattern_len = values_pattern.size();
|
||||
size_t values_expr_len = values_expr.size();
|
||||
|
||||
if (svar_index == -1) {
|
||||
if (values_pattern_len != values_expr_len) {
|
||||
// MS_LOG(DEBUG) << "Structures of differing size: pattern len " << values_pattern_len << ",
|
||||
// expr len " << values_expr_len;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
if (values_expr_len < values_pattern_len - 1) {
|
||||
MS_LOG(DEBUG) << "invalid size: pattern len " << values_pattern_len << ", expr len " << values_expr_len;
|
||||
return nullptr;
|
||||
}
|
||||
size_t diff = values_expr_len - values_pattern_len + 1;
|
||||
for (size_t i = 0; i < values_pattern_len; i++) {
|
||||
size_t expr_i = i;
|
||||
if (svar_index != -1 && i == IntToSize(svar_index)) {
|
||||
auto seq =
|
||||
std::vector<BaseRef>(values_expr.begin() + svar_index, values_expr.begin() + svar_index + SizeToInt(diff));
|
||||
equiv = Match(values_pattern[svar_index], seq, primitive_vars, equiv);
|
||||
} else {
|
||||
if (svar_index != -1 && i > IntToSize(svar_index)) {
|
||||
expr_i = i + diff - 1;
|
||||
}
|
||||
equiv = Match(values_pattern[i], values_expr[expr_i], primitive_vars, equiv);
|
||||
}
|
||||
if (equiv == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
return equiv;
|
||||
}
|
||||
|
||||
EquivPtr PatternEngine::Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars,
|
||||
EquivPtr equiv) const {
|
||||
MS_LOG(DEBUG) << "-----[in Match]";
|
||||
// MS_LOG(DEBUG) << "GetVar w";
|
||||
BaseRef pattern_ref = GetVar(pattern);
|
||||
// MS_LOG(DEBUG) << "GetVar v";
|
||||
BaseRef expr_ref = expr;
|
||||
|
||||
if (equiv == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Equiv pointer is null";
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "Pattern ref " + pattern_ref.ToString() + ", expr ref" + expr_ref.ToString();
|
||||
// 1. if pattern_ref is var and already in equiv, replace it.
|
||||
if (utils::isa<VarPtr>(pattern_ref)) {
|
||||
VarPtr var = utils::cast<VarPtr>(pattern_ref);
|
||||
auto iter = equiv->find(var);
|
||||
if (iter != equiv->end()) {
|
||||
pattern_ref = iter->second;
|
||||
}
|
||||
}
|
||||
|
||||
// 2. check equal
|
||||
if (eq_(pattern_ref, expr_ref)) {
|
||||
return equiv;
|
||||
}
|
||||
|
||||
// 3. match var
|
||||
EquivPtr ret_equiv = MatchOnVar(pattern_ref, expr_ref, equiv);
|
||||
if (ret_equiv) {
|
||||
return ret_equiv;
|
||||
}
|
||||
|
||||
// 4. here the type can be std:vector, std:list,
|
||||
// or cnode.
|
||||
if (!type_eq_(pattern_ref, expr_ref)) {
|
||||
MS_LOG(DEBUG) << "Type mismatch";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// 5. transfer the Containers by visitor to std::vector
|
||||
VectorRef values_pattern;
|
||||
VectorRef values_expr;
|
||||
if (!ToVector(pattern_ref, expr_ref, &values_pattern, &values_expr)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// 6. if any svar in both side, find the SeqVar index,
|
||||
// try to pack the Var s in std::vector to a Seq and match elements one by one.
|
||||
// check svar
|
||||
equiv = AlignSVar(values_pattern, values_expr, primitive_vars, equiv);
|
||||
UpdateEquivMap(values_pattern, expr_ref, primitive_vars, equiv);
|
||||
return equiv;
|
||||
}
|
||||
|
||||
BaseRef PatternEngine::Replace(const BaseRef &pattern, const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
MS_LOG(DEBUG) << "-----[in Replace]";
|
||||
BaseRef ref = GetVar(pattern);
|
||||
BaseRef out;
|
||||
bool is_match = false;
|
||||
|
||||
// w is var
|
||||
if (utils::isa<VarPtr>(ref)) {
|
||||
const VarPtr &var = utils::cast<VarPtr>(ref);
|
||||
auto iter = equiv->find(var);
|
||||
if (iter != equiv->end()) {
|
||||
out = iter->second;
|
||||
is_match = true;
|
||||
}
|
||||
}
|
||||
if (is_match) {
|
||||
return out;
|
||||
}
|
||||
|
||||
// visitor to visit the list
|
||||
std::function<BaseRef(BaseRef)> fn = [&, this, equiv](const BaseRef &u) { return Replace(u, equiv); };
|
||||
|
||||
visitor_->SetFn(fn);
|
||||
BaseRef visit_out;
|
||||
if (!visitor_->Visit(pattern, &visit_out)) {
|
||||
return pattern;
|
||||
}
|
||||
return visit_out;
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
@ -1,203 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_PATTERN_ENGINE_H_
|
||||
#define MINDSPORE_LITE_SRC_PASS_COMMON_PATTERN_ENGINE_H_
|
||||
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <unordered_set>
|
||||
#include <unordered_map>
|
||||
#include <initializer_list>
|
||||
#include <iostream>
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <stdexcept>
|
||||
#include <list>
|
||||
#include <utility>
|
||||
|
||||
#include "src/gllo/common/visit.h"
|
||||
#include "mindspore/core/base/base.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "base/base_ref.h"
|
||||
|
||||
namespace mindspore {
|
||||
class CondVar;
|
||||
class SeqVar;
|
||||
using CondVarPtr = std::shared_ptr<CondVar>;
|
||||
using SVarPtr = std::shared_ptr<SeqVar>;
|
||||
const int kInvalidVarIndex = -2;
|
||||
|
||||
using ConditionFunc = std::function<bool(const BaseRef &)>;
|
||||
|
||||
// Base wildcard variable which could match any anf node.
|
||||
class Var : public Base {
|
||||
friend class VarHasher;
|
||||
|
||||
public:
|
||||
explicit Var(std::string tag = "") : tag_(std::move(tag)), primitive_(nullptr) { EnsureTag(); }
|
||||
explicit Var(const PrimitivePtr &primitive, std::string tag = "") : tag_(std::move(tag)), primitive_(primitive) {
|
||||
EnsureTag();
|
||||
}
|
||||
Var(const Var &other) : Base(other), tag_(other.tag_) {}
|
||||
virtual Var &operator=(const Var &other) {
|
||||
if (&other == this) {
|
||||
return *this;
|
||||
}
|
||||
this->tag_ = other.tag_;
|
||||
return *this;
|
||||
}
|
||||
~Var() override = default;
|
||||
MS_DECLARE_PARENT(Var, Base);
|
||||
|
||||
virtual bool matches(const BaseRef &) { return true; }
|
||||
|
||||
virtual bool operator==(const Var &other) const { return tag_ == other.tag_; }
|
||||
bool operator!=(const Var &other) const { return !(&other == this); }
|
||||
|
||||
std::string tag() const { return tag_; }
|
||||
PrimitivePtr primitive() const { return primitive_; }
|
||||
std::string ToString() const override {
|
||||
std::ostringstream buffer;
|
||||
buffer << "Var(" << tag_ << ")";
|
||||
return buffer.str();
|
||||
}
|
||||
std::size_t hash() const override { return std::hash<std::string>()(tag_); }
|
||||
|
||||
protected:
|
||||
void EnsureTag();
|
||||
|
||||
std::string tag_;
|
||||
PrimitivePtr primitive_;
|
||||
};
|
||||
|
||||
// VarNode means variable node, a subclass of AnfNode
|
||||
class VarNode : public AnfNode {
|
||||
public:
|
||||
VarNode(const VarPtr &value, const FuncGraphPtr &func_graph) : AnfNode(func_graph), var_(value) {}
|
||||
~VarNode() override = default;
|
||||
MS_DECLARE_PARENT(VarNode, AnfNode);
|
||||
|
||||
const VarPtr var_;
|
||||
};
|
||||
using VarNodePtr = std::shared_ptr<VarNode>;
|
||||
|
||||
class VarHasher {
|
||||
public:
|
||||
std::size_t operator()(const Var &var) const { return var.hash(); }
|
||||
};
|
||||
|
||||
// Condition Var, match an anf node when condition function return true.
|
||||
class CondVar : public Var {
|
||||
public:
|
||||
explicit CondVar(const ConditionFunc &cond) : cond_fn_(cond) {}
|
||||
~CondVar() override = default;
|
||||
MS_DECLARE_PARENT(CondVar, Var);
|
||||
bool matches(const BaseRef &value) override {
|
||||
// MS_LOG(DEBUG) << "CondVarPtr match: " + value.ToString();
|
||||
if (utils::isa<Var>(value)) {
|
||||
return false;
|
||||
}
|
||||
return cond_fn_(value);
|
||||
}
|
||||
ConditionFunc cond_fn_;
|
||||
};
|
||||
|
||||
using Seq = VectorRef;
|
||||
using SeqPtr = std::shared_ptr<Seq>;
|
||||
|
||||
// Sequence Var which could match multiple consecutive input nodes of a CNode.
|
||||
class SeqVar : public Var {
|
||||
public:
|
||||
SeqVar() : subvar_(std::make_shared<Var>()) {}
|
||||
~SeqVar() override = default;
|
||||
MS_DECLARE_PARENT(SeqVar, Var);
|
||||
explicit SeqVar(const VarPtr subvar) : subvar_(subvar) {}
|
||||
bool matches(const BaseRef &value) override {
|
||||
// match Seq.
|
||||
if (utils::isa<Seq>(value)) {
|
||||
const Seq &seq = utils::cast<Seq>(value);
|
||||
return std::all_of(seq.begin(), seq.end(), [this](const BaseRef &v) {
|
||||
auto eq = subvar_->matches(v);
|
||||
return eq;
|
||||
});
|
||||
}
|
||||
return false;
|
||||
}
|
||||
bool operator==(const SeqVar &other) const { return *subvar_ == *other.subvar_; }
|
||||
std::string ToString() const override;
|
||||
|
||||
private:
|
||||
VarPtr subvar_;
|
||||
};
|
||||
|
||||
bool operator==(const VarPtr &lhs, const VarPtr &rhs);
|
||||
|
||||
inline bool operator!=(const VarPtr &lhs, const VarPtr &rhs) { return !(lhs == rhs); }
|
||||
|
||||
std::ostream &operator<<(std::ostream &os, const VarPtr &var);
|
||||
|
||||
using Equiv = std::map<VarPtr, BaseRef>;
|
||||
using EquivPtr = std::shared_ptr<Equiv>;
|
||||
using PrimitiveVarMap = std::unordered_map<PrimitivePtr, VarPtr>;
|
||||
using PrimitiveVarMapPtr = std::shared_ptr<PrimitiveVarMap>;
|
||||
|
||||
inline bool DefaultTypeEq(const BaseRef &x, const BaseRef &y) { return x.type() == y.type(); }
|
||||
|
||||
class PatternEngine {
|
||||
public:
|
||||
PatternEngine(const std::shared_ptr<Visitor> &visitor,
|
||||
const std::function<bool(const BaseRef &, const BaseRef &)> &eq,
|
||||
const std::function<bool(const BaseRef &, const BaseRef &)> &type_eq = DefaultTypeEq)
|
||||
: visitor_(visitor), eq_(eq), type_eq_(type_eq) {}
|
||||
~PatternEngine() = default;
|
||||
|
||||
EquivPtr Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars,
|
||||
EquivPtr equiv) const;
|
||||
// Replace pattern with equivalent
|
||||
BaseRef Replace(const BaseRef &pattern, const EquivPtr &equiv) const;
|
||||
|
||||
private:
|
||||
EquivPtr AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr,
|
||||
const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const;
|
||||
bool ToVector(const BaseRef &pattern, const BaseRef &expr, VectorRef *const values_pattern,
|
||||
VectorRef *const values_expr) const;
|
||||
bool ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern,
|
||||
VectorRef *const values_expr) const;
|
||||
std::shared_ptr<Visitor> visitor_;
|
||||
std::function<bool(const BaseRef &, const BaseRef &)> eq_;
|
||||
std::function<bool(const BaseRef &, const BaseRef &)> type_eq_;
|
||||
};
|
||||
} // namespace mindspore
|
||||
namespace std {
|
||||
using mindspore::ERROR;
|
||||
using mindspore::LogStream;
|
||||
using mindspore::NoExceptionType;
|
||||
template <>
|
||||
struct hash<mindspore::VarPtr> {
|
||||
std::size_t operator()(const mindspore::VarPtr var) const {
|
||||
if (var == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid var ptr";
|
||||
return 0;
|
||||
}
|
||||
return std::hash<std::string>{}(var->tag());
|
||||
}
|
||||
};
|
||||
} // namespace std
|
||||
#endif // MINDSPORE_LITE_SRC_PASS_COMMON_PATTERN_ENGINE_H_
|
||||
|
|
@ -1,165 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
|
||||
#include "src/gllo/common/visit.h"
|
||||
#include "src/gllo/common/pattern_engine.h"
|
||||
#include "utils/any.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
|
||||
namespace mindspore {
|
||||
bool CheckIfNeedExpand(const std::vector<BaseRef> &list) {
|
||||
return std::any_of(list.begin(), list.end(), [](const BaseRef &any) { return utils::isa<Seq>(any); });
|
||||
}
|
||||
|
||||
std::shared_ptr<VectorRef> ExpandList(const std::vector<BaseRef> &list) {
|
||||
std::shared_ptr<VectorRef> new_list = std::make_shared<VectorRef>();
|
||||
for (auto &item : list) {
|
||||
if (utils::isa<Seq>(item)) {
|
||||
const Seq &seq = utils::cast<Seq>(item);
|
||||
new_list->insert(new_list->end(), seq.begin(), seq.end());
|
||||
} else {
|
||||
new_list->push_back(item);
|
||||
}
|
||||
}
|
||||
return new_list;
|
||||
}
|
||||
|
||||
bool DefaultVisitor::Visit(const VectorRef &v_any, BaseRef *const visit_out) const {
|
||||
std::vector<BaseRef> out;
|
||||
(void)std::transform(v_any.begin(), v_any.end(), std::back_inserter(out),
|
||||
[this](const BaseRef &item) { return fn_(item); });
|
||||
if (visit_out != nullptr) {
|
||||
*visit_out = ExpandList(out);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool DefaultVisitor::Visit(const BaseRef &any, BaseRef *const visit_out) const {
|
||||
if (utils::isa<Seq>(any)) {
|
||||
return Visit(utils::cast<Seq>(any), visit_out);
|
||||
} else if (utils::isa<AnfNodePtr>(any)) {
|
||||
auto nodeptr = utils::cast<AnfNodePtr>(any);
|
||||
AnfNodePtr output;
|
||||
AnfNodePtr *p_output = &output;
|
||||
if (visit_out == nullptr) {
|
||||
p_output = nullptr;
|
||||
}
|
||||
Visit(nodeptr, fn_, p_output);
|
||||
if (visit_out != nullptr) {
|
||||
*visit_out = output;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
MS_LOG(DEBUG) << "VisitError, not support type to Visit: " + any.ToString();
|
||||
return false;
|
||||
}
|
||||
|
||||
void DefaultVisitor::Visit(const AnfNodePtr &node, const VisitFn &fn, AnfNodePtr *output) const {
|
||||
if (node->isa<CNode>()) {
|
||||
Visit(node->cast<CNodePtr>(), fn, output);
|
||||
return;
|
||||
}
|
||||
|
||||
if (node->isa<ValueNode>()) {
|
||||
Visit(node->cast<ValueNodePtr>(), fn, output);
|
||||
return;
|
||||
}
|
||||
|
||||
if (output != nullptr) {
|
||||
*output = node;
|
||||
}
|
||||
}
|
||||
|
||||
void DefaultVisitor::Visit(const CNodePtr &cnode, const VisitFn &fn, AnfNodePtr *output) const {
|
||||
// if output is nullptr, it's not required to make the new CNode node.
|
||||
if (output == nullptr) {
|
||||
for (auto &inp : cnode->inputs()) {
|
||||
(void)fn(inp);
|
||||
}
|
||||
|
||||
if (cnode->func_graph() != nullptr) {
|
||||
(void)fn(cnode->func_graph());
|
||||
} else {
|
||||
(void)fn(cnode->func_graph_as_var());
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> new_inputs;
|
||||
std::vector<BaseRef> after_cnode_fn;
|
||||
std::shared_ptr<VectorRef> out;
|
||||
(void)std::transform(cnode->inputs().begin(), cnode->inputs().end(), std::back_inserter(after_cnode_fn), fn);
|
||||
if (CheckIfNeedExpand(after_cnode_fn)) {
|
||||
out = ExpandList(after_cnode_fn);
|
||||
}
|
||||
|
||||
std::vector<BaseRef> &outs = after_cnode_fn;
|
||||
if (out != nullptr) {
|
||||
outs = out->elements();
|
||||
}
|
||||
|
||||
for (auto &any_item : outs) {
|
||||
if (!utils::isa<AnfNodePtr>(any_item)) {
|
||||
MS_LOG(EXCEPTION) << "VisitError, fn not return the same type AnfNodePtr";
|
||||
}
|
||||
new_inputs.push_back(utils::cast<AnfNodePtr>(any_item));
|
||||
}
|
||||
|
||||
BaseRef any_fg;
|
||||
AnfNodePtr new_cnode = nullptr;
|
||||
if (cnode->func_graph() != nullptr) {
|
||||
any_fg = fn(cnode->func_graph());
|
||||
if (!utils::isa<FuncGraphPtr>(any_fg)) {
|
||||
MS_LOG(EXCEPTION) << "VisitError, fn not return the same type FuncGraphPtr";
|
||||
}
|
||||
new_cnode = std::make_shared<CNode>(new_inputs, utils::cast<FuncGraphPtr>(any_fg));
|
||||
} else {
|
||||
any_fg = fn(cnode->func_graph_as_var());
|
||||
if (utils::isa<VarPtr>(any_fg)) {
|
||||
new_cnode = std::make_shared<CNode>(new_inputs, utils::cast<VarPtr>(any_fg));
|
||||
} else if (utils::isa<FuncGraphPtr>(any_fg)) {
|
||||
new_cnode = std::make_shared<CNode>(new_inputs, utils::cast<FuncGraphPtr>(any_fg));
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "VisitError, fn not return VarPtr or FuncGraphPtr";
|
||||
}
|
||||
}
|
||||
new_cnode->set_abstract(cnode->abstract());
|
||||
*output = new_cnode;
|
||||
}
|
||||
|
||||
void DefaultVisitor::Visit(const ValueNodePtr &vnode, const VisitFn &fn, AnfNodePtr *output) const {
|
||||
const BaseRef &value = utils::cast<ValuePtr>(fn(vnode->value()));
|
||||
if (utils::isa<ValuePtr>(value)) {
|
||||
if (output != nullptr) {
|
||||
auto ct = NewValueNode(utils::cast<ValuePtr>(value));
|
||||
ct->set_abstract(vnode->abstract());
|
||||
*output = ct;
|
||||
}
|
||||
return;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Visit result is not ValuePtr.";
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
@ -1,59 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LIFT_SRC_PASS_COMMON_VISIT_H_
|
||||
#define MINDSPORE_LIFT_SRC_PASS_COMMON_VISIT_H_
|
||||
|
||||
#include <unordered_map>
|
||||
#include <stdexcept>
|
||||
#include <list>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include "mindspore/core/base/base.h"
|
||||
#include "base/base_ref.h"
|
||||
|
||||
namespace mindspore {
|
||||
using VisitFn = std::function<BaseRef(const BaseRef &)>;
|
||||
|
||||
class Visitor {
|
||||
public:
|
||||
virtual void SetFn(VisitFn fn) = 0;
|
||||
virtual bool Visit(const BaseRef &e, BaseRef *out) const = 0;
|
||||
virtual bool Visit(const VectorRef &e, BaseRef *out) const = 0;
|
||||
virtual ~Visitor() = default;
|
||||
};
|
||||
|
||||
class DefaultVisitor : public Visitor {
|
||||
public:
|
||||
DefaultVisitor() : fn_(nullptr) {}
|
||||
~DefaultVisitor() override = default;
|
||||
void SetFn(VisitFn fn) override { fn_ = fn; };
|
||||
bool Visit(const VectorRef &e, BaseRef *out) const override;
|
||||
bool Visit(const BaseRef &e, BaseRef *out) const override;
|
||||
void Visit(const AnfNodePtr &node, const VisitFn &fn, AnfNodePtr *output) const;
|
||||
void Visit(const CNodePtr &cnode, const VisitFn &fn, AnfNodePtr *output) const;
|
||||
void Visit(const ValueNodePtr &vnode, const VisitFn &fn, AnfNodePtr *output) const;
|
||||
|
||||
VisitFn fn_;
|
||||
};
|
||||
|
||||
std::shared_ptr<VectorRef> ExpandList(const std::vector<BaseRef> &list);
|
||||
bool CheckIfNeedExpand(const std::vector<BaseRef> &list);
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LIFT_SRC_PASS_COMMON_VISIT_H_
|
||||
|
|
@ -14,12 +14,12 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "mindspore/lite/src/gllo/fusion/conv_activation_fusion.h"
|
||||
#include "src/gllo/fusion/conv_activation_fusion.h"
|
||||
#include <memory>
|
||||
#include "mindspore/lite/schema/inner/model_generated.h"
|
||||
#include "mindspore/lite/src/ir/primitive_t_value.h"
|
||||
#include "mindspore/ccsrc/utils/utils.h"
|
||||
#include "mindspore/lite/src/gllo/common/utils.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "src/ir/primitive_t_value.h"
|
||||
#include "utils/utils.h"
|
||||
#include "src/gllo/common/gllo_utils.h"
|
||||
|
||||
namespace mindspore::opt {
|
||||
namespace {
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_ACTIVATION_FUSION_H_
|
||||
|
||||
#include <string>
|
||||
#include "mindspore/lite/src/gllo/common/optimizer.h"
|
||||
#include "src/gllo/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
|
|
@ -13,13 +13,13 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.h"
|
||||
#include <mindspore/lite/src/param_value_lite.h>
|
||||
#include "src/gllo/fusion/conv_biasadd_fusion.h"
|
||||
#include <memory>
|
||||
#include "mindspore/lite/schema/inner/model_generated.h"
|
||||
#include "mindspore/lite/src/ir/primitive_t_value.h"
|
||||
#include "mindspore/ccsrc/utils/utils.h"
|
||||
#include "mindspore/lite/src/gllo/common/utils.h"
|
||||
#include "src/param_value_lite.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "src/ir/primitive_t_value.h"
|
||||
#include "utils/utils.h"
|
||||
#include "src/gllo/common/gllo_utils.h"
|
||||
#include "securec/include/securec.h"
|
||||
|
||||
namespace mindspore::opt {
|
||||
|
@ -142,7 +142,7 @@ const AnfNodePtr ConvBiasaddFusion::Process(const FuncGraphPtr &func_graph, cons
|
|||
CheckIfCNodeIsNull(conv_node);
|
||||
GenConvNewBias(func_graph, conv_node, add_node);
|
||||
auto primitiveT_value = GetValueNode<std::shared_ptr<lite::PrimitiveTValue>>(conv_node->input(0));
|
||||
MS_ASSERT(primitiveT_value);
|
||||
MS_ASSERT(primitiveT_value != nullptr);
|
||||
auto type = primitiveT_value->GetPrimitiveT()->value.type;
|
||||
if (type == schema::PrimitiveType_Conv2D) {
|
||||
primitiveT_value->GetPrimitiveT()->value.AsConv2D()->hasBias = true;
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BIASADD_FUSION_H_
|
||||
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BIASADD_FUSION_H_
|
||||
|
||||
#include "mindspore/lite/src/gllo/common/optimizer.h"
|
||||
#include "src/gllo/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
|
|
@ -14,13 +14,13 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "mindspore/lite/src/gllo/fusion/conv_bn_fusion.h"
|
||||
#include <mindspore/lite/src/param_value_lite.h>
|
||||
#include "src/gllo/fusion/conv_bn_fusion.h"
|
||||
#include <memory>
|
||||
#include "mindspore/lite/schema/inner/model_generated.h"
|
||||
#include "mindspore/lite/src/ir/primitive_t_value.h"
|
||||
#include "mindspore/ccsrc/utils/utils.h"
|
||||
#include "mindspore/lite/src/gllo/common/utils.h"
|
||||
#include "src/param_value_lite.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "src/ir/primitive_t_value.h"
|
||||
#include "utils/utils.h"
|
||||
#include "src/gllo/common/gllo_utils.h"
|
||||
#include "securec/include/securec.h"
|
||||
|
||||
namespace mindspore::opt {
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BN_FUSION_H_
|
||||
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BN_FUSION_H_
|
||||
|
||||
#include "mindspore/lite/src/gllo/fusion/conv_transform_fusion.h"
|
||||
#include "src/gllo/fusion/conv_transform_fusion.h"
|
||||
|
||||
namespace mindspore::opt {
|
||||
class ConvBatchNormFusion : public ConvTransformFusion {
|
||||
|
|
|
@ -14,13 +14,13 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "mindspore/lite/src/gllo/fusion/conv_scale_fusion.h"
|
||||
#include "src/gllo/fusion/conv_scale_fusion.h"
|
||||
#include <memory>
|
||||
#include "mindspore/lite/src/param_value_lite.h"
|
||||
#include "mindspore/lite/schema/inner/model_generated.h"
|
||||
#include "mindspore/lite/src/ir/primitive_t_value.h"
|
||||
#include "mindspore/ccsrc/utils/utils.h"
|
||||
#include "mindspore/lite/src/gllo/common/utils.h"
|
||||
#include "src/param_value_lite.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "src/ir/primitive_t_value.h"
|
||||
#include "utils/utils.h"
|
||||
#include "src/gllo/common/gllo_utils.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "securec/include/securec.h"
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_
|
||||
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_
|
||||
|
||||
#include "mindspore/lite/src/gllo/fusion/conv_transform_fusion.h"
|
||||
#include "src/gllo/fusion/conv_transform_fusion.h"
|
||||
|
||||
namespace mindspore::opt {
|
||||
class ConvScaleFusion : public ConvTransformFusion {
|
||||
|
|
|
@ -14,13 +14,13 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "mindspore/lite/src/gllo/fusion/conv_transform_fusion.h"
|
||||
#include "src/gllo/fusion/conv_transform_fusion.h"
|
||||
#include <memory>
|
||||
#include "mindspore/lite/src/param_value_lite.h"
|
||||
#include "mindspore/lite/schema/inner/model_generated.h"
|
||||
#include "mindspore/lite/src/ir/primitive_t_value.h"
|
||||
#include "mindspore/ccsrc/utils/utils.h"
|
||||
#include "mindspore/lite/src/gllo/common/utils.h"
|
||||
#include "src/param_value_lite.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "src/ir/primitive_t_value.h"
|
||||
#include "utils/utils.h"
|
||||
#include "src/gllo/common/gllo_utils.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "securec/include/securec.h"
|
||||
|
||||
|
@ -78,6 +78,16 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co
|
|||
GenNewConvTensor(func_graph, conv_node, kernel_nums, trans_scale, trans_bias);
|
||||
delete[] trans_bias;
|
||||
delete[] trans_scale;
|
||||
auto primitiveT_value = GetValueNode<std::shared_ptr<lite::PrimitiveTValue>>(conv_node->input(0));
|
||||
MS_ASSERT(primitiveT_value != nullptr);
|
||||
auto type = primitiveT_value->GetPrimitiveT()->value.type;
|
||||
if (type == schema::PrimitiveType_Conv2D) {
|
||||
primitiveT_value->GetPrimitiveT()->value.AsConv2D()->hasBias = true;
|
||||
} else if (type == schema::PrimitiveType_DepthwiseConv2D) {
|
||||
primitiveT_value->GetPrimitiveT()->value.AsDepthwiseConv2D()->hasBias = true;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unsupported opType, " << type;
|
||||
}
|
||||
return pre_node;
|
||||
}
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_TRANSFORM_FUSION_H_
|
||||
|
||||
#include <string>
|
||||
#include "mindspore/lite/src/gllo/common/optimizer.h"
|
||||
#include "src/gllo/common/optimizer.h"
|
||||
|
||||
namespace mindspore::opt {
|
||||
class ConvTransformFusion : public PatternProcessPass {
|
||||
|
|
|
@ -63,6 +63,8 @@ if(BUILD_CONVERTER)
|
|||
${CCSRC_DIR}/pybind_api/export_flags.cc
|
||||
${CCSRC_DIR}/utils/context/context_extends.cc
|
||||
${CCSRC_DIR}/frontend/parallel/costmodel_context.cc
|
||||
${CCSRC_DIR}/backend/optimizer/common/pattern_engine.cc
|
||||
${CCSRC_DIR}/backend/optimizer/common/visit.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../src/common/graph_utils_extends.cc
|
||||
)
|
||||
else()
|
||||
|
@ -202,12 +204,14 @@ if(BUILD_CONVERTER)
|
|||
${LITE_DIR}/tools/converter/converter.cc
|
||||
${LITE_DIR}/tools/converter/parser/onnx/onnx.pb.cc
|
||||
${LITE_DIR}/test/st/converter_test.cc
|
||||
${LITE_DIR}/test/ut/src/gllo/fusion/conv_activation_fusion_test.cc
|
||||
${LITE_DIR}/test/ut/src/gllo/fusion/conv_biasadd_fusion_test.cc
|
||||
${LITE_DIR}/test/ut/src/gllo/fusion/conv_bn_fusion_test.cc
|
||||
${LITE_DIR}/test/ut/src/gllo/fusion/conv_scale_fusion_test.cc
|
||||
${LITE_DIR}/src/gllo/common/node_pass.cc
|
||||
${LITE_DIR}/src/gllo/common/optimizer.cc
|
||||
${LITE_DIR}/src/gllo/common/pass_manager.cc
|
||||
${LITE_DIR}/src/gllo/common/pattern_engine.cc
|
||||
${LITE_DIR}/src/gllo/common/visit.cc
|
||||
${LITE_DIR}/src/gllo/common/utils.cc
|
||||
${LITE_DIR}/src/gllo/common/gllo_utils.cc
|
||||
${LITE_DIR}/src/gllo/fusion/conv_biasadd_fusion.cc
|
||||
${LITE_DIR}/src/gllo/fusion/conv_activation_fusion.cc
|
||||
${LITE_DIR}/src/gllo/fusion/conv_transform_fusion.cc
|
||||
|
|
|
@ -0,0 +1,184 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "include/model.h"
|
||||
#include "common/common_test.h"
|
||||
#include "include/lite_session.h"
|
||||
#include "include/context.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "tools/converter/anf_transform.h"
|
||||
#include "src/common/anf_exporter/anf_exporter.h"
|
||||
|
||||
namespace mindspore {
|
||||
class ConvActivationFusionTest : public mindspore::Common {
|
||||
public:
|
||||
ConvActivationFusionTest() = default;
|
||||
};
|
||||
using MetaGraphTptr = std::shared_ptr<schema::MetaGraphT>;
|
||||
using CNodeTptr = std::unique_ptr<schema::CNodeT>;
|
||||
|
||||
namespace {
|
||||
CNodeTptr BuildConv2D() {
|
||||
auto convNode = std::make_unique<schema::CNodeT>();
|
||||
convNode->inputIndex = {0, 1};
|
||||
convNode->outputIndex = {2};
|
||||
convNode->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
convNode->primitive->value.type = schema::PrimitiveType_Conv2D;
|
||||
auto prim1 = new schema::Conv2DT;
|
||||
prim1->padMode = schema::PadMode_SAME;
|
||||
prim1->format = schema::Format_NHWC;
|
||||
prim1->strideH = 1;
|
||||
prim1->strideW = 1;
|
||||
prim1->kernelH = 3;
|
||||
prim1->kernelW = 3;
|
||||
prim1->dilateH = 1;
|
||||
prim1->dilateW = 1;
|
||||
prim1->channelOut = 3;
|
||||
convNode->primitive->value.value = prim1;
|
||||
convNode->name = "Conv2D";
|
||||
return convNode;
|
||||
}
|
||||
CNodeTptr BuildDepthwiseConv2D() {
|
||||
auto convNode = std::make_unique<schema::CNodeT>();
|
||||
convNode->inputIndex = {0, 1};
|
||||
convNode->outputIndex = {2};
|
||||
convNode->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
convNode->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
|
||||
auto prim1 = new schema::DepthwiseConv2DT;
|
||||
prim1->padMode = schema::PadMode_SAME;
|
||||
prim1->format = schema::Format_NHWC;
|
||||
prim1->strideH = 1;
|
||||
prim1->strideW = 1;
|
||||
prim1->kernelH = 3;
|
||||
prim1->kernelW = 3;
|
||||
prim1->dilateH = 1;
|
||||
prim1->dilateW = 1;
|
||||
prim1->channelIn = 1;
|
||||
prim1->channelMultiplier = 3;
|
||||
convNode->primitive->value.value = prim1;
|
||||
convNode->name = "Conv2D";
|
||||
return convNode;
|
||||
}
|
||||
|
||||
MetaGraphTptr BuildGraph(schema::PrimitiveType conv_type,
|
||||
schema::ActivationType activation_type) {
|
||||
auto meta_graph = std::make_shared<schema::MetaGraphT>();
|
||||
meta_graph->name = "graph";
|
||||
// conv node
|
||||
CNodeTptr convNode;
|
||||
if (conv_type == schema::PrimitiveType_Conv2D) {
|
||||
convNode = BuildConv2D();
|
||||
} else {
|
||||
convNode = BuildDepthwiseConv2D();
|
||||
}
|
||||
meta_graph->nodes.emplace_back(std::move(convNode));
|
||||
|
||||
// relu node
|
||||
auto next_node = std::make_unique<schema::CNodeT>();
|
||||
next_node->inputIndex = {2};
|
||||
next_node->outputIndex = {3};
|
||||
next_node->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
next_node->primitive->value.type = schema::PrimitiveType_Activation;
|
||||
auto prim2 = new schema::ActivationT;
|
||||
prim2->type = activation_type;
|
||||
next_node->primitive->value.value = prim2;
|
||||
next_node->name = "activation";
|
||||
meta_graph->nodes.emplace_back(std::move(next_node));
|
||||
|
||||
meta_graph->inputIndex = {0};
|
||||
meta_graph->outputIndex = {3};
|
||||
|
||||
// input 0: data
|
||||
auto input0 = std::make_unique<schema::TensorT>();
|
||||
input0->nodeType = schema::NodeType::NodeType_ValueNode;
|
||||
input0->format = schema::Format_NHWC;
|
||||
input0->dataType = TypeId::kNumberTypeFloat32;
|
||||
input0->dims = {1, 5, 5, 3};
|
||||
input0->offset = -1;
|
||||
meta_graph->allTensors.emplace_back(std::move(input0));
|
||||
|
||||
// input 1: weight
|
||||
auto input1 = std::make_unique<schema::TensorT>();
|
||||
input1->nodeType = schema::NodeType::NodeType_ValueNode;
|
||||
input1->format = schema::Format_KHWC;
|
||||
input1->dataType = TypeId::kNumberTypeFloat32;
|
||||
input1->dims = {8, 3, 3, 3};
|
||||
input1->data.resize(sizeof(float) * 8 * 3 * 3 * 3);
|
||||
meta_graph->allTensors.emplace_back(std::move(input1));
|
||||
|
||||
// conv output
|
||||
auto conv_output = std::make_unique<schema::TensorT>();
|
||||
conv_output->nodeType = schema::NodeType::NodeType_Parameter;
|
||||
conv_output->format = schema::Format_NHWC;
|
||||
conv_output->dataType = TypeId::kNumberTypeFloat32;
|
||||
conv_output->dims = {1, 5, 5, 8};
|
||||
meta_graph->allTensors.emplace_back(std::move(conv_output));
|
||||
|
||||
// final output
|
||||
auto output = std::make_unique<schema::TensorT>();
|
||||
output->nodeType = schema::NodeType::NodeType_Parameter;
|
||||
output->format = schema::Format_NHWC;
|
||||
output->dataType = TypeId::kNumberTypeFloat32;
|
||||
output->dims = {1, 5, 5, 8};
|
||||
meta_graph->allTensors.emplace_back(std::move(output));
|
||||
return meta_graph;
|
||||
}
|
||||
} // namespace
|
||||
TEST_F(ConvActivationFusionTest, TestConvReluNode) {
|
||||
auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2D, schema::ActivationType_RELU);
|
||||
auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
|
||||
auto anf_transform = new lite::AnfTransform();
|
||||
auto new_graph = anf_transform->Transform(func_graph);
|
||||
ASSERT_NE(nullptr, new_graph);
|
||||
auto new_meta_graph = lite::Export(new_graph);
|
||||
ASSERT_EQ(new_meta_graph->nodes.size(), 1);
|
||||
for (auto &cnode : new_meta_graph->nodes) {
|
||||
ASSERT_EQ(cnode->primitive->value.AsConv2D()->activationType, schema::ActivationType_RELU);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ConvActivationFusionTest, TestConvRelu6Node) {
|
||||
auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2D, schema::ActivationType_RELU6);
|
||||
auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
|
||||
auto anf_transform = new lite::AnfTransform();
|
||||
auto new_graph = anf_transform->Transform(func_graph);
|
||||
ASSERT_NE(nullptr, new_graph);
|
||||
auto new_meta_graph = lite::Export(new_graph);
|
||||
ASSERT_EQ(new_meta_graph->nodes.size(), 1);
|
||||
for (auto &cnode : new_meta_graph->nodes) {
|
||||
ASSERT_EQ(cnode->primitive->value.AsConv2D()->activationType, schema::ActivationType_RELU6);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ConvActivationFusionTest, TestBadCase_ConvRelu) {
|
||||
auto meta_graph = BuildGraph(schema::PrimitiveType_DepthwiseConv2D, schema::ActivationType_LEAKY_RELU);
|
||||
auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
|
||||
auto anf_transform = new lite::AnfTransform();
|
||||
auto new_graph = anf_transform->Transform(func_graph);
|
||||
ASSERT_NE(nullptr, new_graph);
|
||||
auto new_meta_graph = lite::Export(new_graph);
|
||||
ASSERT_EQ(new_meta_graph->nodes.size(), 2);
|
||||
for (auto &cnode : new_meta_graph->nodes) {
|
||||
if (cnode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) {
|
||||
ASSERT_EQ(cnode->primitive->value.AsDepthwiseConv2D()->activationType, schema::ActivationType_NO_ACTIVATION);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,194 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "include/model.h"
|
||||
#include "common/common_test.h"
|
||||
#include "include/lite_session.h"
|
||||
#include "include/context.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "tools/converter/anf_transform.h"
|
||||
#include "src/common/anf_exporter/anf_exporter.h"
|
||||
|
||||
namespace mindspore {
|
||||
class ConvBiasAddFusionTest : public mindspore::Common {
|
||||
public:
|
||||
ConvBiasAddFusionTest() = default;
|
||||
};
|
||||
using MetaGraphTptr = std::shared_ptr<schema::MetaGraphT>;
|
||||
using CNodeTptr = std::unique_ptr<schema::CNodeT>;
|
||||
|
||||
namespace {
|
||||
CNodeTptr BuildConv2D() {
|
||||
auto convNode = std::make_unique<schema::CNodeT>();
|
||||
convNode->inputIndex = {0, 1};
|
||||
convNode->outputIndex = {2};
|
||||
convNode->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
convNode->primitive->value.type = schema::PrimitiveType_Conv2D;
|
||||
auto prim1 = new schema::Conv2DT;
|
||||
prim1->padMode = schema::PadMode_SAME;
|
||||
prim1->format = schema::Format_NHWC;
|
||||
prim1->strideH = 1;
|
||||
prim1->strideW = 1;
|
||||
prim1->kernelH = 3;
|
||||
prim1->kernelW = 3;
|
||||
prim1->dilateH = 1;
|
||||
prim1->dilateW = 1;
|
||||
prim1->channelOut = 3;
|
||||
convNode->primitive->value.value = prim1;
|
||||
convNode->name = "Conv2D";
|
||||
return convNode;
|
||||
}
|
||||
CNodeTptr BuildDepthwiseConv2D() {
|
||||
auto convNode = std::make_unique<schema::CNodeT>();
|
||||
convNode->inputIndex = {0, 1};
|
||||
convNode->outputIndex = {2};
|
||||
convNode->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
convNode->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
|
||||
auto prim1 = new schema::DepthwiseConv2DT;
|
||||
prim1->padMode = schema::PadMode_SAME;
|
||||
prim1->format = schema::Format_NHWC;
|
||||
prim1->strideH = 1;
|
||||
prim1->strideW = 1;
|
||||
prim1->kernelH = 3;
|
||||
prim1->kernelW = 3;
|
||||
prim1->dilateH = 1;
|
||||
prim1->dilateW = 1;
|
||||
prim1->channelIn = 1;
|
||||
prim1->channelMultiplier = 3;
|
||||
convNode->primitive->value.value = prim1;
|
||||
convNode->name = "Conv2D";
|
||||
return convNode;
|
||||
}
|
||||
|
||||
MetaGraphTptr BuildGraph(schema::PrimitiveType conv_type,
|
||||
schema::PrimitiveType add_type) {
|
||||
auto meta_graph = std::make_shared<schema::MetaGraphT>();
|
||||
meta_graph->name = "graph";
|
||||
// conv node
|
||||
CNodeTptr convNode;
|
||||
if (conv_type == schema::PrimitiveType_Conv2D) {
|
||||
convNode = BuildConv2D();
|
||||
} else {
|
||||
convNode = BuildDepthwiseConv2D();
|
||||
}
|
||||
|
||||
meta_graph->nodes.emplace_back(std::move(convNode));
|
||||
|
||||
// biasadd node
|
||||
auto biasadd_node = std::make_unique<schema::CNodeT>();
|
||||
biasadd_node->inputIndex = {2, 3};
|
||||
biasadd_node->outputIndex = {4};
|
||||
biasadd_node->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
biasadd_node->primitive->value.type = add_type;
|
||||
auto prim2 = new schema::BiasAddT;
|
||||
biasadd_node->primitive->value.value = prim2;
|
||||
biasadd_node->name = "BiasAdd";
|
||||
meta_graph->nodes.emplace_back(std::move(biasadd_node));
|
||||
|
||||
meta_graph->inputIndex = {0};
|
||||
meta_graph->outputIndex = {4};
|
||||
|
||||
// input 0: data
|
||||
auto input0 = std::make_unique<schema::TensorT>();
|
||||
input0->nodeType = schema::NodeType::NodeType_ValueNode;
|
||||
input0->format = schema::Format_NHWC;
|
||||
input0->dataType = TypeId::kNumberTypeFloat32;
|
||||
input0->dims = {1, 5, 5, 3};
|
||||
input0->offset = -1;
|
||||
meta_graph->allTensors.emplace_back(std::move(input0));
|
||||
|
||||
// input 1: weight
|
||||
auto input1 = std::make_unique<schema::TensorT>();
|
||||
input1->nodeType = schema::NodeType::NodeType_ValueNode;
|
||||
input1->format = schema::Format_KHWC;
|
||||
input1->dataType = TypeId::kNumberTypeFloat32;
|
||||
input1->dims = {8, 3, 3, 3};
|
||||
input1->data.resize(sizeof(float) * 8 * 3 * 3 * 3);
|
||||
meta_graph->allTensors.emplace_back(std::move(input1));
|
||||
|
||||
// conv output
|
||||
auto conv_output = std::make_unique<schema::TensorT>();
|
||||
conv_output->nodeType = schema::NodeType::NodeType_Parameter;
|
||||
conv_output->format = schema::Format_NHWC;
|
||||
conv_output->dataType = TypeId::kNumberTypeFloat32;
|
||||
conv_output->dims = {1, 5, 5, 8};
|
||||
meta_graph->allTensors.emplace_back(std::move(conv_output));
|
||||
|
||||
// input2: bias
|
||||
auto input2 = std::make_unique<schema::TensorT>();
|
||||
input2->nodeType = schema::NodeType::NodeType_ValueNode;
|
||||
input2->format = schema::Format_NHWC;
|
||||
input2->dataType = TypeId::kNumberTypeFloat32;
|
||||
input2->dims = {1, 5, 5, 8};
|
||||
input2->data.resize(sizeof(float) * 8 * 5 * 5);
|
||||
meta_graph->allTensors.emplace_back(std::move(input2));
|
||||
|
||||
// final output
|
||||
auto output = std::make_unique<schema::TensorT>();
|
||||
output->nodeType = schema::NodeType::NodeType_Parameter;
|
||||
output->format = schema::Format_NHWC;
|
||||
output->dataType = TypeId::kNumberTypeFloat32;
|
||||
output->dims = {1, 5, 5, 8};
|
||||
meta_graph->allTensors.emplace_back(std::move(output));
|
||||
return meta_graph;
|
||||
}
|
||||
} // namespace
|
||||
TEST_F(ConvBiasAddFusionTest, TestConvAddNode) {
|
||||
auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2D, schema::PrimitiveType_BiasAdd);
|
||||
auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
|
||||
auto anf_transform = new lite::AnfTransform();
|
||||
auto new_graph = anf_transform->Transform(func_graph);
|
||||
ASSERT_NE(nullptr, new_graph);
|
||||
auto new_meta_graph = lite::Export(new_graph);
|
||||
ASSERT_EQ(new_meta_graph->nodes.size(), 1);
|
||||
for (auto &cnode : new_meta_graph->nodes) {
|
||||
ASSERT_EQ(cnode->primitive->value.AsConv2D()->hasBias, true);
|
||||
}
|
||||
MS_LOG(INFO) << "Passed";
|
||||
}
|
||||
|
||||
TEST_F(ConvBiasAddFusionTest, TestDeptiwiseConvAddNode) {
|
||||
auto meta_graph = BuildGraph(schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_Add);
|
||||
auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
|
||||
auto anf_transform = new lite::AnfTransform();
|
||||
auto new_graph = anf_transform->Transform(func_graph);
|
||||
ASSERT_NE(nullptr, new_graph);
|
||||
auto new_meta_graph = lite::Export(new_graph);
|
||||
ASSERT_EQ(new_meta_graph->nodes.size(), 1);
|
||||
for (auto &cnode : new_meta_graph->nodes) {
|
||||
ASSERT_EQ(cnode->primitive->value.AsDepthwiseConv2D()->hasBias, true);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ConvBiasAddFusionTest, TestBadCase_ConvAdd) {
|
||||
auto meta_graph = BuildGraph(schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_MatMul);
|
||||
auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
|
||||
auto anf_transform = new lite::AnfTransform();
|
||||
auto new_graph = anf_transform->Transform(func_graph);
|
||||
ASSERT_NE(nullptr, new_graph);
|
||||
auto new_meta_graph = lite::Export(new_graph);
|
||||
ASSERT_EQ(new_meta_graph->nodes.size(), 2);
|
||||
for (auto &cnode : new_meta_graph->nodes) {
|
||||
if (cnode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) {
|
||||
ASSERT_EQ(cnode->primitive->value.AsDepthwiseConv2D()->hasBias, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,296 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "include/model.h"
|
||||
#include "common/common_test.h"
|
||||
#include "include/lite_session.h"
|
||||
#include "include/context.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "mindspore/core/utils/log_adapter.h"
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "tools/converter/anf_transform.h"
|
||||
#include "src/common/anf_exporter/anf_exporter.h"
|
||||
|
||||
namespace mindspore {
|
||||
class ConvBNFusionTest : public mindspore::Common {
|
||||
public:
|
||||
ConvBNFusionTest() = default;
|
||||
};
|
||||
using MetaGraphTptr = std::shared_ptr<schema::MetaGraphT>;
|
||||
using CNodeTptr = std::unique_ptr<schema::CNodeT>;
|
||||
|
||||
namespace {
|
||||
CNodeTptr BuildConv2D() {
|
||||
auto convNode = std::make_unique<schema::CNodeT>();
|
||||
convNode->inputIndex = {0, 1};
|
||||
convNode->outputIndex = {2};
|
||||
convNode->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
convNode->primitive->value.type = schema::PrimitiveType_Conv2D;
|
||||
auto prim1 = new schema::Conv2DT;
|
||||
prim1->padMode = schema::PadMode_SAME;
|
||||
prim1->format = schema::Format_NHWC;
|
||||
prim1->strideH = 1;
|
||||
prim1->strideW = 1;
|
||||
prim1->kernelH = 3;
|
||||
prim1->kernelW = 3;
|
||||
prim1->dilateH = 1;
|
||||
prim1->dilateW = 1;
|
||||
prim1->channelOut = 3;
|
||||
convNode->primitive->value.value = prim1;
|
||||
convNode->name = "Conv2D";
|
||||
return convNode;
|
||||
}
|
||||
CNodeTptr BuildDepthwiseConv2D() {
|
||||
auto convNode = std::make_unique<schema::CNodeT>();
|
||||
convNode->inputIndex = {0, 1, 2};
|
||||
convNode->outputIndex = {3};
|
||||
convNode->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
convNode->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
|
||||
auto prim1 = new schema::DepthwiseConv2DT;
|
||||
prim1->padMode = schema::PadMode_SAME;
|
||||
prim1->format = schema::Format_NHWC;
|
||||
prim1->strideH = 1;
|
||||
prim1->strideW = 1;
|
||||
prim1->kernelH = 3;
|
||||
prim1->kernelW = 3;
|
||||
prim1->dilateH = 1;
|
||||
prim1->dilateW = 1;
|
||||
prim1->channelIn = 1;
|
||||
prim1->channelMultiplier = 3;
|
||||
|
||||
convNode->primitive->value.value = prim1;
|
||||
convNode->name = "Conv2D";
|
||||
return convNode;
|
||||
}
|
||||
// caffe bn op has 3 inputs
|
||||
MetaGraphTptr BuildCaffeGraph(schema::PrimitiveType conv_type) {
|
||||
auto meta_graph = std::make_shared<schema::MetaGraphT>();
|
||||
meta_graph->name = "graph";
|
||||
// conv node
|
||||
CNodeTptr convNode;
|
||||
if (conv_type == schema::PrimitiveType_Conv2D) {
|
||||
convNode = BuildConv2D();
|
||||
} else {
|
||||
convNode = BuildDepthwiseConv2D();
|
||||
}
|
||||
|
||||
meta_graph->nodes.emplace_back(std::move(convNode));
|
||||
|
||||
// bn_node
|
||||
auto bn_node = std::make_unique<schema::CNodeT>();
|
||||
bn_node->inputIndex = {2, 3, 4};
|
||||
bn_node->outputIndex = {5};
|
||||
bn_node->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
bn_node->primitive->value.type = schema::PrimitiveType_CaffeBatchNorm;
|
||||
auto prim2 = new schema::CaffeBatchNormT;
|
||||
bn_node->primitive->value.value = prim2;
|
||||
bn_node->name = "bn";
|
||||
meta_graph->nodes.emplace_back(std::move(bn_node));
|
||||
|
||||
// input 0: data
|
||||
auto input0 = std::make_unique<schema::TensorT>();
|
||||
input0->nodeType = schema::NodeType::NodeType_ValueNode;
|
||||
input0->format = schema::Format_NHWC;
|
||||
input0->dataType = TypeId::kNumberTypeFloat32;
|
||||
input0->dims = {1, 5, 5, 3};
|
||||
input0->offset = -1;
|
||||
meta_graph->allTensors.emplace_back(std::move(input0));
|
||||
|
||||
// input 1: weight
|
||||
auto input1 = std::make_unique<schema::TensorT>();
|
||||
input1->nodeType = schema::NodeType::NodeType_ValueNode;
|
||||
input1->format = schema::Format_KHWC;
|
||||
input1->dataType = TypeId::kNumberTypeFloat32;
|
||||
input1->dims = {8, 3, 3, 3};
|
||||
input1->data.resize(sizeof(float) * 8 * 3 * 3 * 3);
|
||||
meta_graph->allTensors.emplace_back(std::move(input1));
|
||||
|
||||
// conv output
|
||||
auto conv_output = std::make_unique<schema::TensorT>();
|
||||
conv_output->nodeType = schema::NodeType::NodeType_Parameter;
|
||||
conv_output->format = schema::Format_NHWC;
|
||||
conv_output->dataType = TypeId::kNumberTypeFloat32;
|
||||
conv_output->dims = {1, 5, 5, 8};
|
||||
meta_graph->allTensors.emplace_back(std::move(conv_output));
|
||||
|
||||
// caffe bn : mean
|
||||
auto input2 = std::make_unique<schema::TensorT>();
|
||||
input2->nodeType = schema::NodeType::NodeType_ValueNode;
|
||||
input2->format = schema::Format_NHWC;
|
||||
input2->dataType = TypeId::kNumberTypeFloat32;
|
||||
input2->dims = {1, 5, 5, 8};
|
||||
input2->data.resize(sizeof(float) * 8 * 5 * 5);
|
||||
meta_graph->allTensors.emplace_back(std::move(input2));
|
||||
|
||||
// caffe bn : var
|
||||
auto input3 = std::make_unique<schema::TensorT>();
|
||||
input3->nodeType = schema::NodeType::NodeType_ValueNode;
|
||||
input3->format = schema::Format_NHWC;
|
||||
input3->dataType = TypeId::kNumberTypeFloat32;
|
||||
input3->dims = {1, 5, 5, 8};
|
||||
input3->data.resize(sizeof(float) * 8 * 5 * 5);
|
||||
meta_graph->allTensors.emplace_back(std::move(input3));
|
||||
|
||||
|
||||
// final bn output
|
||||
auto output = std::make_unique<schema::TensorT>();
|
||||
output->nodeType = schema::NodeType::NodeType_Parameter;
|
||||
output->format = schema::Format_NHWC;
|
||||
output->dataType = TypeId::kNumberTypeFloat32;
|
||||
output->dims = {1, 5, 5, 8};
|
||||
meta_graph->allTensors.emplace_back(std::move(output));
|
||||
meta_graph->inputIndex = {0};
|
||||
meta_graph->outputIndex = {5};
|
||||
return meta_graph;
|
||||
}
|
||||
|
||||
// tf bn op has 4 inputs
|
||||
MetaGraphTptr BuildTFGraph(schema::PrimitiveType conv_type) {
|
||||
auto meta_graph = std::make_shared<schema::MetaGraphT>();
|
||||
meta_graph->name = "graph";
|
||||
// conv node
|
||||
CNodeTptr convNode;
|
||||
if (conv_type == schema::PrimitiveType_Conv2D) {
|
||||
convNode = BuildConv2D();
|
||||
} else {
|
||||
convNode = BuildDepthwiseConv2D();
|
||||
}
|
||||
|
||||
meta_graph->nodes.emplace_back(std::move(convNode));
|
||||
|
||||
// bn_node
|
||||
auto bn_node = std::make_unique<schema::CNodeT>();
|
||||
bn_node->inputIndex = {3, 4, 5, 6, 7};
|
||||
bn_node->outputIndex = {8};
|
||||
bn_node->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
bn_node->primitive->value.type = schema::PrimitiveType_FusedBatchNorm;
|
||||
auto prim2 = new schema::FusedBatchNormT;
|
||||
bn_node->primitive->value.value = prim2;
|
||||
bn_node->name = "bn";
|
||||
meta_graph->nodes.emplace_back(std::move(bn_node));
|
||||
|
||||
// input 0: data
|
||||
auto input0 = std::make_unique<schema::TensorT>();
|
||||
input0->nodeType = schema::NodeType::NodeType_ValueNode;
|
||||
input0->format = schema::Format_NHWC;
|
||||
input0->dataType = TypeId::kNumberTypeFloat32;
|
||||
input0->dims = {1, 5, 5, 3};
|
||||
input0->offset = -1;
|
||||
meta_graph->allTensors.emplace_back(std::move(input0));
|
||||
|
||||
|
||||
// input 1: conv_bias
|
||||
auto input11 = std::make_unique<schema::TensorT>();
|
||||
input11->nodeType = schema::NodeType::NodeType_ValueNode;
|
||||
input11->format = schema::Format_KHWC;
|
||||
input11->dataType = TypeId::kNumberTypeFloat32;
|
||||
input11->dims = {8, 3, 3, 3};
|
||||
input11->data.resize(sizeof(float) * 8 * 3 * 3 * 3);
|
||||
meta_graph->allTensors.emplace_back(std::move(input11));
|
||||
|
||||
// input 1: weight
|
||||
auto input1 = std::make_unique<schema::TensorT>();
|
||||
input1->nodeType = schema::NodeType::NodeType_ValueNode;
|
||||
input1->format = schema::Format_KHWC;
|
||||
input1->dataType = TypeId::kNumberTypeFloat32;
|
||||
input1->dims = {8, 3, 3, 3};
|
||||
input1->data.resize(sizeof(float) * 8 * 3 * 3 * 3);
|
||||
meta_graph->allTensors.emplace_back(std::move(input1));
|
||||
|
||||
// conv output
|
||||
auto conv_output = std::make_unique<schema::TensorT>();
|
||||
conv_output->nodeType = schema::NodeType::NodeType_Parameter;
|
||||
conv_output->format = schema::Format_NHWC;
|
||||
conv_output->dataType = TypeId::kNumberTypeFloat32;
|
||||
conv_output->dims = {1, 5, 5, 8};
|
||||
meta_graph->allTensors.emplace_back(std::move(conv_output));
|
||||
|
||||
// tflite bn : scale
|
||||
auto input2 = std::make_unique<schema::TensorT>();
|
||||
input2->nodeType = schema::NodeType::NodeType_ValueNode;
|
||||
input2->format = schema::Format_NHWC;
|
||||
input2->dataType = TypeId::kNumberTypeFloat32;
|
||||
input2->dims = {1, 5, 5, 8};
|
||||
input2->data.resize(sizeof(float) * 8 * 5 * 5);
|
||||
meta_graph->allTensors.emplace_back(std::move(input2));
|
||||
|
||||
// tflite bn : bias
|
||||
auto input3 = std::make_unique<schema::TensorT>();
|
||||
input3->nodeType = schema::NodeType::NodeType_ValueNode;
|
||||
input3->format = schema::Format_NHWC;
|
||||
input3->dataType = TypeId::kNumberTypeFloat32;
|
||||
input3->dims = {1, 5, 5, 8};
|
||||
input3->data.resize(sizeof(float) * 8 * 5 * 5);
|
||||
meta_graph->allTensors.emplace_back(std::move(input3));
|
||||
|
||||
// tflite bn : mean
|
||||
auto input4 = std::make_unique<schema::TensorT>();
|
||||
input4->nodeType = schema::NodeType::NodeType_ValueNode;
|
||||
input4->format = schema::Format_NHWC;
|
||||
input4->dataType = TypeId::kNumberTypeFloat32;
|
||||
input4->dims = {1, 5, 5, 8};
|
||||
input4->data.resize(sizeof(float) * 8 * 5 * 5);
|
||||
meta_graph->allTensors.emplace_back(std::move(input4));
|
||||
|
||||
// tflite bn : var
|
||||
auto input5 = std::make_unique<schema::TensorT>();
|
||||
input5->nodeType = schema::NodeType::NodeType_ValueNode;
|
||||
input5->format = schema::Format_NHWC;
|
||||
input5->dataType = TypeId::kNumberTypeFloat32;
|
||||
input5->dims = {1, 5, 5, 8};
|
||||
input5->data.resize(sizeof(float) * 8 * 5 * 5);
|
||||
meta_graph->allTensors.emplace_back(std::move(input5));
|
||||
|
||||
// final output
|
||||
auto output = std::make_unique<schema::TensorT>();
|
||||
output->nodeType = schema::NodeType::NodeType_Parameter;
|
||||
output->format = schema::Format_NHWC;
|
||||
output->dataType = TypeId::kNumberTypeFloat32;
|
||||
output->dims = {1, 5, 5, 8};
|
||||
meta_graph->allTensors.emplace_back(std::move(output));
|
||||
meta_graph->inputIndex = {0};
|
||||
meta_graph->outputIndex = {8};
|
||||
return meta_graph;
|
||||
}
|
||||
} // namespace
|
||||
TEST_F(ConvBNFusionTest, TestConvAddNode) {
|
||||
auto meta_graph = BuildCaffeGraph(schema::PrimitiveType_Conv2D);
|
||||
auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
|
||||
auto anf_transform = new lite::AnfTransform();
|
||||
auto new_graph = anf_transform->Transform(func_graph);
|
||||
ASSERT_NE(nullptr, new_graph);
|
||||
auto new_meta_graph = lite::Export(new_graph);
|
||||
ASSERT_EQ(new_meta_graph->nodes.size(), 1);
|
||||
for (auto &cnode : new_meta_graph->nodes) {
|
||||
ASSERT_EQ(cnode->primitive->value.AsConv2D()->hasBias, true);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ConvBNFusionTest, TestDeptiwiseConvAddNode) {
|
||||
auto meta_graph = BuildTFGraph(schema::PrimitiveType_DepthwiseConv2D);
|
||||
auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
|
||||
auto anf_transform = new lite::AnfTransform();
|
||||
auto new_graph = anf_transform->Transform(func_graph);
|
||||
ASSERT_NE(nullptr, new_graph);
|
||||
auto new_meta_graph = lite::Export(new_graph);
|
||||
ASSERT_EQ(new_meta_graph->nodes.size(), 1);
|
||||
for (auto &cnode : new_meta_graph->nodes) {
|
||||
ASSERT_EQ(cnode->primitive->value.AsDepthwiseConv2D()->hasBias, true);
|
||||
}
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,221 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "include/model.h"
|
||||
#include "common/common_test.h"
|
||||
#include "include/lite_session.h"
|
||||
#include "include/context.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "tools/converter/anf_transform.h"
|
||||
#include "src/common/anf_exporter/anf_exporter.h"
|
||||
|
||||
namespace mindspore {
|
||||
class ConvScaleFusionTest : public mindspore::Common {
|
||||
public:
|
||||
ConvScaleFusionTest() = default;
|
||||
};
|
||||
using MetaGraphTptr = std::shared_ptr<schema::MetaGraphT>;
|
||||
using CNodeTptr = std::unique_ptr<schema::CNodeT>;
|
||||
|
||||
namespace {
|
||||
// conv has 2 inputs
|
||||
CNodeTptr BuildConv2D(int with_bias_flag) {
|
||||
auto convNode = std::make_unique<schema::CNodeT>();
|
||||
if (with_bias_flag) {
|
||||
convNode->inputIndex = {0, 1, 2};
|
||||
convNode->outputIndex = {3};
|
||||
} else {
|
||||
convNode->inputIndex = {0, 1};
|
||||
convNode->outputIndex = {2};
|
||||
}
|
||||
convNode->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
convNode->primitive->value.type = schema::PrimitiveType_Conv2D;
|
||||
auto prim1 = new schema::Conv2DT;
|
||||
prim1->padMode = schema::PadMode_SAME;
|
||||
prim1->format = schema::Format_NHWC;
|
||||
prim1->strideH = 1;
|
||||
prim1->strideW = 1;
|
||||
prim1->kernelH = 3;
|
||||
prim1->kernelW = 3;
|
||||
prim1->dilateH = 1;
|
||||
prim1->dilateW = 1;
|
||||
prim1->channelOut = 3;
|
||||
convNode->primitive->value.value = prim1;
|
||||
convNode->name = "Conv2D";
|
||||
return convNode;
|
||||
}
|
||||
// conv2d has 3 inputs
|
||||
CNodeTptr BuildDepthwiseConv2D(int with_bias_flag) {
|
||||
auto convNode = std::make_unique<schema::CNodeT>();
|
||||
if (with_bias_flag) {
|
||||
convNode->inputIndex = {0, 1, 2};
|
||||
convNode->outputIndex = {3};
|
||||
} else {
|
||||
convNode->inputIndex = {0, 1};
|
||||
convNode->outputIndex = {2};
|
||||
}
|
||||
convNode->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
convNode->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
|
||||
auto prim1 = new schema::DepthwiseConv2DT;
|
||||
prim1->padMode = schema::PadMode_SAME;
|
||||
prim1->format = schema::Format_NHWC;
|
||||
prim1->strideH = 1;
|
||||
prim1->strideW = 1;
|
||||
prim1->kernelH = 3;
|
||||
prim1->kernelW = 3;
|
||||
prim1->dilateH = 1;
|
||||
prim1->dilateW = 1;
|
||||
prim1->channelIn = 1;
|
||||
prim1->channelMultiplier = 3;
|
||||
|
||||
convNode->primitive->value.value = prim1;
|
||||
convNode->name = "Conv2D";
|
||||
return convNode;
|
||||
}
|
||||
|
||||
MetaGraphTptr BuildGraph(schema::PrimitiveType conv_type, bool conv_with_bias) {
|
||||
auto meta_graph = std::make_shared<schema::MetaGraphT>();
|
||||
meta_graph->name = "graph";
|
||||
// conv node
|
||||
CNodeTptr convNode;
|
||||
if (conv_type == schema::PrimitiveType_Conv2D) {
|
||||
convNode = BuildConv2D(conv_with_bias);
|
||||
} else {
|
||||
convNode = BuildDepthwiseConv2D(conv_with_bias);
|
||||
}
|
||||
|
||||
meta_graph->nodes.emplace_back(std::move(convNode));
|
||||
|
||||
// scale_node weight bias
|
||||
auto scale_node = std::make_unique<schema::CNodeT>();
|
||||
if (conv_with_bias) {
|
||||
scale_node->inputIndex = {3, 4, 5};
|
||||
scale_node->outputIndex = {6};
|
||||
} else {
|
||||
scale_node->inputIndex = {2, 3, 4};
|
||||
scale_node->outputIndex = {5};
|
||||
}
|
||||
|
||||
scale_node->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
scale_node->primitive->value.type = schema::PrimitiveType_Scale;
|
||||
auto prim2 = new schema::ScaleT;
|
||||
scale_node->primitive->value.value = prim2;
|
||||
scale_node->name = "scale";
|
||||
meta_graph->nodes.emplace_back(std::move(scale_node));
|
||||
|
||||
// input 0: data
|
||||
auto input0 = std::make_unique<schema::TensorT>();
|
||||
input0->nodeType = schema::NodeType::NodeType_ValueNode;
|
||||
input0->format = schema::Format_NHWC;
|
||||
input0->dataType = TypeId::kNumberTypeFloat32;
|
||||
input0->dims = {1, 5, 5, 3};
|
||||
input0->offset = -1;
|
||||
meta_graph->allTensors.emplace_back(std::move(input0));
|
||||
|
||||
// input 1: weight
|
||||
auto input1 = std::make_unique<schema::TensorT>();
|
||||
input1->nodeType = schema::NodeType::NodeType_ValueNode;
|
||||
input1->format = schema::Format_KHWC;
|
||||
input1->dataType = TypeId::kNumberTypeFloat32;
|
||||
input1->dims = {8, 3, 3, 3};
|
||||
input1->data.resize(sizeof(float) * 8 * 3 * 3 * 3);
|
||||
meta_graph->allTensors.emplace_back(std::move(input1));
|
||||
|
||||
if (conv_with_bias) {
|
||||
// input 00: bias
|
||||
auto input00 = std::make_unique<schema::TensorT>();
|
||||
input00->nodeType = schema::NodeType::NodeType_ValueNode;
|
||||
input00->format = schema::Format_NHWC;
|
||||
input00->dataType = TypeId::kNumberTypeFloat32;
|
||||
input00->dims = {1, 5, 5, 3};
|
||||
input00->offset = -1;
|
||||
meta_graph->allTensors.emplace_back(std::move(input00));
|
||||
}
|
||||
|
||||
// conv output
|
||||
auto conv_output = std::make_unique<schema::TensorT>();
|
||||
conv_output->nodeType = schema::NodeType::NodeType_Parameter;
|
||||
conv_output->format = schema::Format_NHWC;
|
||||
conv_output->dataType = TypeId::kNumberTypeFloat32;
|
||||
conv_output->dims = {1, 5, 5, 8};
|
||||
meta_graph->allTensors.emplace_back(std::move(conv_output));
|
||||
|
||||
// scale weight input
|
||||
auto input2 = std::make_unique<schema::TensorT>();
|
||||
input2->nodeType = schema::NodeType::NodeType_ValueNode;
|
||||
input2->format = schema::Format_NHWC;
|
||||
input2->dataType = TypeId::kNumberTypeFloat32;
|
||||
input2->dims = {1, 5, 5, 8};
|
||||
input2->data.resize(sizeof(float) * 8 * 5 * 5);
|
||||
meta_graph->allTensors.emplace_back(std::move(input2));
|
||||
|
||||
// scale bias input
|
||||
auto input3 = std::make_unique<schema::TensorT>();
|
||||
input3->nodeType = schema::NodeType::NodeType_ValueNode;
|
||||
input3->format = schema::Format_NHWC;
|
||||
input3->dataType = TypeId::kNumberTypeFloat32;
|
||||
input3->dims = {1, 5, 5, 8};
|
||||
input3->data.resize(sizeof(float) * 8 * 5 * 5);
|
||||
meta_graph->allTensors.emplace_back(std::move(input3));
|
||||
|
||||
// final scale output
|
||||
auto output = std::make_unique<schema::TensorT>();
|
||||
output->nodeType = schema::NodeType::NodeType_Parameter;
|
||||
output->format = schema::Format_NHWC;
|
||||
output->dataType = TypeId::kNumberTypeFloat32;
|
||||
output->dims = {1, 5, 5, 8};
|
||||
meta_graph->allTensors.emplace_back(std::move(output));
|
||||
if (conv_with_bias) {
|
||||
meta_graph->inputIndex = {0};
|
||||
meta_graph->outputIndex = {6};
|
||||
} else {
|
||||
meta_graph->inputIndex = {0};
|
||||
meta_graph->outputIndex = {5};
|
||||
}
|
||||
return meta_graph;
|
||||
}
|
||||
} // namespace
|
||||
TEST_F(ConvScaleFusionTest, TestConvScaleNode) {
|
||||
auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2D, true);
|
||||
auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
|
||||
auto anf_transform = new lite::AnfTransform();
|
||||
auto new_graph = anf_transform->Transform(func_graph);
|
||||
ASSERT_NE(nullptr, new_graph);
|
||||
auto new_meta_graph = lite::Export(new_graph);
|
||||
ASSERT_EQ(new_meta_graph->nodes.size(), 1);
|
||||
for (auto &cnode : new_meta_graph->nodes) {
|
||||
ASSERT_EQ(cnode->primitive->value.AsConv2D()->hasBias, true);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ConvScaleFusionTest, TestDeptiwiseConvScaleNode) {
|
||||
auto meta_graph = BuildGraph(schema::PrimitiveType_DepthwiseConv2D, false);
|
||||
auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
|
||||
auto anf_transform = new lite::AnfTransform();
|
||||
auto new_graph = anf_transform->Transform(func_graph);
|
||||
ASSERT_NE(nullptr, new_graph);
|
||||
auto new_meta_graph = lite::Export(new_graph);
|
||||
ASSERT_EQ(new_meta_graph->nodes.size(), 1);
|
||||
for (auto &cnode : new_meta_graph->nodes) {
|
||||
ASSERT_EQ(cnode->primitive->value.AsDepthwiseConv2D()->hasBias, true);
|
||||
ASSERT_EQ(cnode->inputIndex.size(), 3);
|
||||
}
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -49,6 +49,8 @@ set(ANF_SRC
|
|||
${CCSRC_DIR}/pybind_api/export_flags.cc
|
||||
${CCSRC_DIR}/utils/context/context_extends.cc
|
||||
${CCSRC_DIR}/frontend/parallel/costmodel_context.cc
|
||||
${CCSRC_DIR}/backend/optimizer/common/pattern_engine.cc
|
||||
${CCSRC_DIR}/backend/optimizer/common/visit.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/graph_utils_extends.cc
|
||||
)
|
||||
|
||||
|
@ -75,9 +77,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/node_pass.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/optimizer.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/pass_manager.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/pattern_engine.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/visit.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/utils.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/gllo_utils.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/fusion/conv_biasadd_fusion.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/fusion/conv_activation_fusion.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/fusion/conv_transform_fusion.cc
|
||||
|
|
|
@ -90,7 +90,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
// auto newGraph = anfTransform->Transform(graph);
|
||||
graph = anfTransform->Transform(graph);
|
||||
|
||||
CreateQuantizer(graph, flag);
|
||||
if (mQuantizer != nullptr) {
|
||||
|
|
|
@ -100,20 +100,20 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
|||
// }
|
||||
|
||||
// fusion
|
||||
{
|
||||
Optimizer fusionOptimizer;
|
||||
fusionOptimizer.AddPass(new (std::nothrow) ConvBiasAddFusionPass());
|
||||
fusionOptimizer.AddPass(new (std::nothrow) ConvBNFusionPass());
|
||||
fusionOptimizer.AddPass(new (std::nothrow) ConvScaleFusionPass());
|
||||
fusionOptimizer.AddPass(new (std::nothrow) ConvReluFusionPass());
|
||||
fusionOptimizer.AddPass(new (std::nothrow) ConvRelu6FusionPass());
|
||||
fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
|
||||
status = fusionOptimizer.Run(graphDefT);
|
||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed";
|
||||
return status;
|
||||
}
|
||||
}
|
||||
// {
|
||||
// Optimizer fusionOptimizer;
|
||||
// fusionOptimizer.AddPass(new (std::nothrow) ConvBiasAddFusionPass());
|
||||
// fusionOptimizer.AddPass(new (std::nothrow) ConvBNFusionPass());
|
||||
// fusionOptimizer.AddPass(new (std::nothrow) ConvScaleFusionPass());
|
||||
// fusionOptimizer.AddPass(new (std::nothrow) ConvReluFusionPass());
|
||||
// fusionOptimizer.AddPass(new (std::nothrow) ConvRelu6FusionPass());
|
||||
// fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
|
||||
// status = fusionOptimizer.Run(graphDefT);
|
||||
// if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||
// MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed";
|
||||
// return status;
|
||||
// }
|
||||
// }
|
||||
|
||||
// weight format trans
|
||||
if (ctx.formatTrans) {
|
||||
|
|
Loading…
Reference in New Issue