!3973 enable convert anf fusion pass and optimize

Merge pull request !3973 from zhengjun10/master
mindspore-ci-bot 2020-08-06 19:54:35 +08:00 committed by Gitee
commit d6547a46e7
30 changed files with 1115 additions and 1027 deletions

View File

@@ -13,15 +13,107 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/gllo/common/gllo_utils.h"
#include <vector>
#include <memory>
#include "src/gllo/common/utils.h"
#include "src/ir/primitive_t_value.h"
#include "frontend/operator/ops.h"
using PrimitiveTValuePtr = std::shared_ptr<mindspore::lite::PrimitiveTValue>;
namespace mindspore {
namespace opt {
namespace {
constexpr auto kAnfPrimitiveIndex = 0;
bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
return false;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
}
bool IsRealKernel(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
// a parameter or a value node is not a real kernel either
if (!node->isa<CNode>()) {
return true;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().empty()) {
MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << node->DebugString();
}
auto input = cnode->inputs()[0];
bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) ||
IsPrimitive(input, prim::kPrimTensorSummary) ||
IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) ||
IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial);
return !is_virtual_node;
}
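// Illustrative sketch (hypothetical helper, not part of this commit): using the
// predicates above to walk only the real kernels of a graph.
void VisitRealKernels(const FuncGraphPtr &func_graph) {
  for (const auto &node : func_graph->nodes()) {
    if (!IsRealKernel(node)) {
      continue;  // skip virtual nodes such as MakeTuple, Depend and Return
    }
    MS_LOG(DEBUG) << "real kernel: " << node->DebugString();
  }
}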
ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
if (utils::isa<int>(sexp)) {
return NewValueNode(utils::cast<int>(sexp));
}
if (utils::isa<float>(sexp)) {
return NewValueNode(utils::cast<float>(sexp));
}
if (utils::isa<bool>(sexp)) {
return NewValueNode(utils::cast<bool>(sexp));
}
if (utils::isa<ValuePtr>(sexp)) {
return NewValueNode(utils::cast<ValuePtr>(sexp));
}
return nullptr;
}
CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
if (utils::isa<FuncGraphPtr>(graph)) {
return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
}
if (utils::isa<VarPtr>(graph)) {
return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
}
return nullptr;
}
VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
if (utils::isa<VarPtr>(graph)) {
MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
}
if (utils::isa<FuncGraphPtr>(graph)) {
MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
}
MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
return nullptr;
}
AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
bool multigraph) {
MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
std::vector<AnfNodePtr> input_nodes;
const auto &tuple = utils::cast<VectorRef>(sexp);
if (multigraph && utils::isa<VarPtr>(graph)) {
for (auto &x : tuple) {
AnfNodePtr node = SexpToNode(x, std::make_shared<Var>("G"), primitive_vars, true);
input_nodes.push_back(node);
}
VarPtr var_ptr = utils::cast<VarPtr>(graph);
return std::make_shared<CNode>(input_nodes, var_ptr);
}
for (auto &x : tuple) {
AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph);
input_nodes.push_back(node);
}
return CreateCNodeWithGraph(input_nodes, graph);
}
} // namespace
bool AnfEqual(const BaseRef &a, const BaseRef &b) {
if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
auto a_node = utils::cast<AnfNodePtr>(a);
@@ -64,15 +156,15 @@ bool AnfEqual(const BaseRef &a, const BaseRef &b) {
}
if (utils::isa<lite::PrimitiveTValue>(a_value_ptr) && utils::isa<lite::PrimitiveTValue>(b_value_ptr)) {
auto a_obj = (lite::PrimitiveTValue *)(a_value_ptr.get());
auto b_obj = (lite::PrimitiveTValue *)(b_value_ptr.get());
auto a_obj = (lite::PrimitiveTValue *) (a_value_ptr.get());
auto b_obj = (lite::PrimitiveTValue *) (b_value_ptr.get());
return (*a_obj) == (*b_obj);
} else {
return (*a_value_ptr) == (*b_value_ptr);
}
}
}
if (a.m_ptr->isa<lite::PrimitiveTValue>()) {
if (a.m_ptr->isa<lite::PrimitiveTValue>() && b.m_ptr->isa<lite::PrimitiveTValue>()) {
auto a_value_node_ptr = a.m_ptr->cast<PrimitiveTValuePtr>();
auto b_value_node_ptr = b.m_ptr->cast<PrimitiveTValuePtr>();
return a_value_node_ptr->GetPrimitiveT()->value.type == b_value_node_ptr->GetPrimitiveT()->value.type;
@@ -89,70 +181,8 @@ bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
return a.type() == b.type();
}
namespace {
ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
if (utils::isa<int>(sexp)) {
return NewValueNode(utils::cast<int>(sexp));
}
if (utils::isa<float>(sexp)) {
return NewValueNode(utils::cast<float>(sexp));
}
if (utils::isa<bool>(sexp)) {
return NewValueNode(utils::cast<bool>(sexp));
}
if (utils::isa<ValuePtr>(sexp)) {
return NewValueNode(utils::cast<ValuePtr>(sexp));
}
return nullptr;
}
CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
if (utils::isa<FuncGraphPtr>(graph)) {
return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
}
if (utils::isa<VarPtr>(graph)) {
return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
}
return nullptr;
}
VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
if (utils::isa<VarPtr>(graph)) {
// MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
}
if (utils::isa<FuncGraphPtr>(graph)) {
// MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
}
MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
return nullptr;
}
AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
bool multigraph) {
// MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
std::vector<AnfNodePtr> input_nodes;
const auto &tuple = utils::cast<VectorRef>(sexp);
if (multigraph && utils::isa<VarPtr>(graph)) {
for (auto &x : tuple) {
AnfNodePtr node = SexpToNode(x, std::make_shared<Var>("G"), primitive_vars, true);
input_nodes.push_back(node);
}
VarPtr var_ptr = utils::cast<VarPtr>(graph);
return std::make_shared<CNode>(input_nodes, var_ptr);
}
for (auto &x : tuple) {
AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph);
input_nodes.push_back(node);
}
return CreateCNodeWithGraph(input_nodes, graph);
}
} // namespace
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
// MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
MS_EXCEPTION_IF_NULL(primitive_vars);
if (utils::isa<VectorRef>(sexp)) {
return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
@@ -176,6 +206,38 @@ AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap
return value_node;
}
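// Illustrative sketch (hypothetical, not from this commit): building an AnfNode
// pattern from an s-expression, the way a PatternProcessPass uses SexpToNode.
// The primitive and the helper name are assumptions for the example.
AnfNodePtr BuildAddNPattern() {
  auto inputs = std::make_shared<SeqVar>();                  // matches any input run
  VectorRef pattern({prim::kPrimAddN, inputs});              // (AddN x1 ... xn)
  auto primitive_vars = std::make_shared<PrimitiveVarMap>();
  auto graph_var = std::make_shared<Var>("G");               // multigraph placeholder
  return SexpToNode(pattern, graph_var, primitive_vars.get(), true);
}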
bool IsRealCNodeKernel(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
// a parameter or a value node is not a real cnode kernel
if (!node->isa<CNode>()) {
return false;
}
// Return is considered a real node
if (CheckPrimitiveType(node, prim::kPrimReturn)) {
return true;
}
return IsRealKernel(node);
}
bool IsGraphKernel(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
// graph kernel should be a real cnode kernel.
if (!IsRealCNodeKernel(node)) {
return false;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto input = cnode->input(kAnfPrimitiveIndex);
// a graph kernel should have a func_graph as its first input.
if (!IsValueNode<FuncGraph>(input)) {
return false;
}
auto func_graph = GetValueNode<FuncGraphPtr>(input);
MS_EXCEPTION_IF_NULL(func_graph);
return func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
}
void CheckIfFuncGraphIsNull(const FuncGraphPtr &graph) {
if (graph == nullptr) {
MS_LOG(EXCEPTION) << "The graph is null.";

View File

@@ -14,22 +14,21 @@
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_UTILS_H_
#define MINDSPORE_LITE_SRC_PASS_COMMON_UTILS_H_
#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_
#define MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_
#include <mindspore/lite/src/ir/primitive_t_value.h>
#include <memory>
#include "src/ir/primitive_t_value.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "src/common/utils.h"
#include "src/gllo/common/pattern_engine.h"
#include "backend/optimizer/common/pattern_engine.h"
#include "schema/inner/model_generated.h"
#include "src/param_value_lite.h"
using PrimitiveTValuePtr = std::shared_ptr<mindspore::lite::PrimitiveTValue>;
namespace mindspore {
namespace opt {
bool AnfEqual(const BaseRef &a, const BaseRef &b);
bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b);
@@ -37,6 +36,10 @@ bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b);
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
bool multigraph = false);
bool IsRealCNodeKernel(const AnfNodePtr &node);
bool IsGraphKernel(const AnfNodePtr &node);
void CheckIfFuncGraphIsNull(const FuncGraphPtr &graph);
void CheckIfAnfNodeIsNull(const AnfNodePtr &node);
@@ -61,4 +64,4 @@ bool IsParamNode(const BaseRef &n);
bool IsConvNode(const BaseRef &n);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_PASS_COMMON_UTILS_H_
#endif // MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_

View File

@@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/gllo/common/node_pass.h"
#include "backend/optimizer/common/node_pass.h"
#include <unordered_set>
#include <deque>
@@ -22,6 +22,7 @@
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ir/manager.h"
#include "src/gllo/common/gllo_utils.h"
namespace mindspore {
namespace opt {
@@ -54,6 +55,9 @@ bool NodePass::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(const_func_graph);
todo.push_back(const_func_graph->output());
} else if (new_node && new_node->isa<CNode>()) {
if (IsGraphKernel(new_node)) {
todo.push_back(new_node);
}
auto cnode = new_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto inputs = cnode->inputs();

View File

@@ -1,36 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_NODE_PASS_H_
#define MINDSPORE_LITE_SRC_PASS_COMMON_NODE_PASS_H_
#include <string>
#include <memory>
#include "src/gllo/common/pass.h"
namespace mindspore {
namespace opt {
// @brief ANF Node level optimization base pass
class NodePass : public Pass {
public:
explicit NodePass(const std::string &name) : Pass(name) {}
~NodePass() override = default;
bool Run(const FuncGraphPtr &func_graph) final;
virtual AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) = 0;
};
using NodePassPtr = std::shared_ptr<NodePass>;
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_PASS_COMMON_NODE_PASS_H_
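// Illustrative sketch (hypothetical pass, not from this commit): a concrete
// NodePass that rewrites one node at a time. Run returns the replacement node,
// or nullptr when the node is left untouched.
class DropDependPass : public NodePass {
 public:
  DropDependPass() : NodePass("drop_depend") {}
  ~DropDependPass() override = default;
  AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override {
    if (node == nullptr || !node->isa<CNode>()) {
      return nullptr;
    }
    auto cnode = node->cast<CNodePtr>();
    // Depend(prim, real_input, control_input): forward the real input.
    if (IsPrimitive(cnode->input(0), prim::kPrimDepend) && cnode->inputs().size() == 3) {
      return cnode->input(1);
    }
    return nullptr;
  }
};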

View File

@@ -23,8 +23,7 @@
#include <utility>
#include <initializer_list>
#include "src/gllo/common/pass_manager.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/common/pass_manager.h"
#include "ir/manager.h"
namespace mindspore {

View File

@@ -26,9 +26,9 @@
#include "ir/graph_utils.h"
#include "src/common/utils.h"
#include "src/gllo/common/pass_manager.h"
#include "src/gllo/common/pattern_engine.h"
#include "src/gllo/common/utils.h"
#include "backend/optimizer/common/pass_manager.h"
#include "backend/optimizer/common/pattern_engine.h"
#include "src/gllo/common/gllo_utils.h"
namespace mindspore {
namespace opt {

View File

@@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/gllo/common/pass_manager.h"
#include "backend/optimizer/common/pass_manager.h"
#include <sys/time.h>
#include <unordered_set>

View File

@@ -1,61 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_PASS_MANAGER_H_
#define MINDSPORE_LITE_SRC_PASS_COMMON_PASS_MANAGER_H_
#include <utility>
#include <vector>
#include <string>
#include <memory>
#include "src/gllo/common/pass.h"
#include "src/gllo/common/node_pass.h"
namespace mindspore {
namespace opt {
// @brief For optimization passes management
class PassManager {
public:
explicit PassManager(const std::string &name = "pm", bool run_only_once = true)
: name_(name), passes_{}, run_only_once_(run_only_once) {}
virtual ~PassManager() = default;
// Get all the passes added by AddPass
const std::vector<PassPtr> &Passes() const;
// Add graph pass, the pass object will be freed when pass manager freed.
void AddPass(const PassPtr &pass);
// Run passes added in pass manager on the input graph
// @param [inout] graph The graph to be optimized
// @return true, graph changed
// @return false, graph not changed
bool Run(const FuncGraphPtr &func_graph) const;
// Run the given graph passes on the input graph
// @param [inout] graph The graph to be optimized
// @param [in] passes The given graph passes
// @return true, graph changed
// @return false, graph not changed
bool Run(const FuncGraphPtr &func_graph, const std::vector<PassPtr> &passes) const;
std::string name() const { return name_; }
private:
const std::string name_;
std::vector<PassPtr> passes_;
bool run_only_once_;
};
using PassManagerPtr = std::shared_ptr<PassManager>;
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_PASS_COMMON_PASS_MANAGER_H_
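// Illustrative sketch (hypothetical wiring, not from this commit): composing the
// fusion passes touched by this PR into a pass manager. Constructor arguments
// are simplified assumptions; the real converter wiring may differ.
bool RunAnfFusion(const FuncGraphPtr &func_graph) {
  auto pm = std::make_shared<opt::PassManager>("anf_fusion_pm", false);
  pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
  pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>());
  pm->AddPass(std::make_shared<opt::ConvScaleFusion>());
  pm->AddPass(std::make_shared<opt::ConvActivationFusion>());
  return pm->Run(func_graph);  // true if any pass changed the graph
}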

View File

@@ -1,365 +0,0 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/gllo/common/pattern_engine.h"
#include <exception>
#include <iostream>
#include <functional>
#include <iterator>
#include "ir/func_graph.h"
#include "mindspore/core/ir/primitive.h"
#include "utils/info.h"
#include "ir/anf.h"
#include "utils/convert_utils_base.h"
#include "utils/overload.h"
namespace mindspore {
static int GetNextTag() {
static int kID = 0;
return kID++;
}
void Var::EnsureTag() {
if (tag_.length() == 0) {
std::ostringstream buffer;
buffer << "_" << GetNextTag();
tag_ = buffer.str();
}
}
bool operator==(const VarPtr &lhs, const VarPtr &rhs) {
if (lhs->isa<CondVar>() && rhs->isa<CondVar>()) {
CondVarPtr v1 = dyn_cast<CondVar>(lhs);
CondVarPtr v2 = dyn_cast<CondVar>(rhs);
return *v1 == *v2;
}
if (lhs->isa<SeqVar>() && rhs->isa<SeqVar>()) {
SVarPtr v1 = dyn_cast<SeqVar>(lhs);
SVarPtr v2 = dyn_cast<SeqVar>(rhs);
return *v1 == *v2;
}
return (*lhs == *rhs);
}
std::string SeqVar::ToString() const {
std::ostringstream buffer;
buffer << "SeqVar(" << tag() << ", " << subvar_->ToString() << ")";
return buffer.str();
}
std::ostream &operator<<(std::ostream &os, const VarPtr &var) {
if (var == nullptr) {
os << "";
} else {
os << var->ToString();
}
return os;
}
template <>
std::ostream &operator<<<VarPtr, BaseRef>(std::ostream &os, const Equiv &equiv) {
os << "[Equiv]"
<< "\n";
for (auto &equiv_item : equiv) {
auto k = equiv_item.first;
os << k << ":";
BaseRef x = equiv_item.second;
if (utils::isa<AnfNodePtr>(x)) {
auto node = utils::cast<AnfNodePtr>(x);
os << "TypeString[" << node->type_name() << "]";
if (IsValueNode<FuncGraph>(node)) {
os << "IsValueNodeGraph ";
}
os << "type " << node->type_name();
if (node->isa<ValueNode>()) {
os << " value " << GetValueNode(node);
}
os << " addr: " << node;
} else if (utils::isa<Named>(x)) {
os << "Named " << x.ToString().c_str();
} else if (utils::isa<VarPtr>(x)) {
os << "TypeString[Var]";
os << utils::cast<VarPtr>(x);
} else if (utils::isa<FuncGraphPtr>(x)) {
os << "TypeString[Graph]";
}
os << "\n";
}
return os;
}
static BaseRef GetVar(const BaseRef &x) {
// MS_LOG(DEBUG) << "getVar start :%s" + x.ToString();
if (utils::isa<AnfNodePtr>(x)) {
auto node = utils::cast<AnfNodePtr>(x);
// MS_LOG(DEBUG) << "TypeString [" + node->type_name() + "]";
if (node->isa<VarNode>()) {
// MS_LOG(DEBUG) << "IsVarNode " + node->cast<VarNodePtr>()->var_->ToString();
return node->cast<VarNodePtr>()->var_;
}
// if (node->isa<ValueNode>()) {
// MS_LOG(DEBUG) << "value " + GetValueNode(node)->ToString() + " addr: " + node->ToString();
// } else {
// MS_LOG(DEBUG) << "type " + node->type_name();
// }
// } else if (utils::isa<Named>(x)) {
// MS_LOG(DEBUG) << "Named " + x.ToString();
// } else if (utils::isa<VectorRef>(x)) {
// MS_LOG(DEBUG) << "VectorRef";
// } else if (utils::isa<VarPtr>(x)) {
// MS_LOG(DEBUG) << "TypeString[Var] " + x.ToString();
}
// MS_LOG(DEBUG) << "GetVar end: " + x.ToString();
return x;
}
EquivPtr MatchOnVar(const BaseRef &pattern, const BaseRef &expr, EquivPtr equiv) {
MS_LOG(DEBUG) << "MatchOnVar pattern " + pattern.ToString() + " expr: " + expr.ToString();
MS_EXCEPTION_IF_NULL(equiv);
if (utils::isa<VarPtr>(pattern)) {
VarPtr var = utils::cast<VarPtr>(pattern);
if (var->matches(expr)) {
(*equiv)[var] = expr;
MS_LOG(DEBUG) << "pattern is var match: " + pattern.ToString() + ", " + expr.ToString();
return equiv;
}
}
return nullptr;
}
bool PatternEngine::ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern,
VectorRef *const values_expr) const {
MS_EXCEPTION_IF_NULL(values_expr);
if (utils::isa<SeqPtr>(pattern_ref)) {
*values_pattern = pattern_ref;
*values_expr = expr_ref;
return true;
}
return false;
}
bool PatternEngine::ToVector(const BaseRef &pattern_ref, const BaseRef &expr_ref, VectorRef *const values_pattern,
VectorRef *const values_expr) const {
MS_EXCEPTION_IF_NULL(values_expr);
// visitor to visit the list
auto appender_pattern = [](VectorRef &values) {
std::function<BaseRef(const BaseRef &)> fn = [&](const BaseRef &u) {
values.push_back(GetVar(u));
return u;
};
return fn;
};
visitor_->SetFn(appender_pattern(*values_pattern));
// MS_LOG(DEBUG) << "visit pattern_ref";
bool success = visitor_->Visit(pattern_ref, nullptr);
if (!success) {
return false;
}
auto appender_expr = [](VectorRef &values) {
std::function<BaseRef(const BaseRef &)> fn = [&](const BaseRef &u) {
values.push_back(u);
return u;
};
return fn;
};
visitor_->SetFn(appender_expr(*values_expr));
// MS_LOG(DEBUG) << "visit expr_ref";
return visitor_->Visit(expr_ref, nullptr);
}
static int GetSVarStartIndex(const VectorRef &values) {
int index = -1;
int count = 0;
for (auto &value : values) {
if (utils::isa<VarPtr>(value) && utils::cast<VarPtr>(value)->isa<SeqVar>()) {
if (index != -1) {
// MS_LOG(DEBUG) << "Multiple SVars in sequence";
return kInvalidVarIndex;
}
index = count;
}
count++;
}
return index;
}
void UpdateEquivMap(const VectorRef &values_pattern, const BaseRef &expr_ref, const PrimitiveVarMap &primitive_vars,
EquivPtr equiv) {
if (equiv == nullptr || values_pattern.empty() || !utils::isa<AnfNodePtr>(values_pattern[0]) ||
!utils::isa<AnfNodePtr>(expr_ref)) {
return;
}
auto real_node = utils::cast<AnfNodePtr>(expr_ref);
MS_EXCEPTION_IF_NULL(real_node);
if (!real_node->isa<CNode>()) {
return;
}
auto prim_node = utils::cast<AnfNodePtr>(values_pattern[0]);
MS_EXCEPTION_IF_NULL(prim_node);
if (!IsValueNode<Primitive>(prim_node)) {
return;
}
ValuePtr value = GetValueNode(prim_node);
MS_EXCEPTION_IF_NULL(value);
auto prim = value->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(prim);
auto iter = primitive_vars.find(prim);
if (iter == primitive_vars.end()) {
return;
}
(*equiv)[iter->second] = real_node;
}
EquivPtr PatternEngine::AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr,
const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const {
int svar_index = GetSVarStartIndex(values_pattern);
if (svar_index == kInvalidVarIndex) {
return nullptr;
}
size_t values_pattern_len = values_pattern.size();
size_t values_expr_len = values_expr.size();
if (svar_index == -1) {
if (values_pattern_len != values_expr_len) {
// MS_LOG(DEBUG) << "Structures of differing size: pattern len " << values_pattern_len << ",
// expr len " << values_expr_len;
return nullptr;
}
}
if (values_expr_len < values_pattern_len - 1) {
MS_LOG(DEBUG) << "invalid size: pattern len " << values_pattern_len << ", expr len " << values_expr_len;
return nullptr;
}
size_t diff = values_expr_len - values_pattern_len + 1;
for (size_t i = 0; i < values_pattern_len; i++) {
size_t expr_i = i;
if (svar_index != -1 && i == IntToSize(svar_index)) {
auto seq =
std::vector<BaseRef>(values_expr.begin() + svar_index, values_expr.begin() + svar_index + SizeToInt(diff));
equiv = Match(values_pattern[svar_index], seq, primitive_vars, equiv);
} else {
if (svar_index != -1 && i > IntToSize(svar_index)) {
expr_i = i + diff - 1;
}
equiv = Match(values_pattern[i], values_expr[expr_i], primitive_vars, equiv);
}
if (equiv == nullptr) {
return nullptr;
}
}
return equiv;
}
EquivPtr PatternEngine::Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars,
EquivPtr equiv) const {
MS_LOG(DEBUG) << "-----[in Match]";
// MS_LOG(DEBUG) << "GetVar w";
BaseRef pattern_ref = GetVar(pattern);
// MS_LOG(DEBUG) << "GetVar v";
BaseRef expr_ref = expr;
if (equiv == nullptr) {
MS_LOG(EXCEPTION) << "Equiv pointer is null";
}
MS_LOG(DEBUG) << "Pattern ref " + pattern_ref.ToString() + ", expr ref" + expr_ref.ToString();
// 1. if pattern_ref is var and already in equiv, replace it.
if (utils::isa<VarPtr>(pattern_ref)) {
VarPtr var = utils::cast<VarPtr>(pattern_ref);
auto iter = equiv->find(var);
if (iter != equiv->end()) {
pattern_ref = iter->second;
}
}
// 2. check equal
if (eq_(pattern_ref, expr_ref)) {
return equiv;
}
// 3. match var
EquivPtr ret_equiv = MatchOnVar(pattern_ref, expr_ref, equiv);
if (ret_equiv) {
return ret_equiv;
}
// 4. here the type can be std::vector, std::list, or cnode.
if (!type_eq_(pattern_ref, expr_ref)) {
MS_LOG(DEBUG) << "Type mismatch";
return nullptr;
}
// 5. transfer the Containers by visitor to std::vector
VectorRef values_pattern;
VectorRef values_expr;
if (!ToVector(pattern_ref, expr_ref, &values_pattern, &values_expr)) {
return nullptr;
}
// 6. if there is an SVar on either side, find the SeqVar index,
// pack the matching Vars from the std::vector into a Seq, and match elements one by one.
// check svar
equiv = AlignSVar(values_pattern, values_expr, primitive_vars, equiv);
UpdateEquivMap(values_pattern, expr_ref, primitive_vars, equiv);
return equiv;
}
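// Illustrative walkthrough of the SeqVar alignment above (hypothetical values,
// not from this commit):
//   pattern flattened: {AddN, SVar}    expr flattened: {AddN, a, b, c}
//   values_pattern_len = 2, values_expr_len = 4, svar_index = 1,
//   diff = 4 - 2 + 1 = 3, so the SVar is matched against Seq{a, b, c}
//   and recorded in equiv as a single binding.
// A caller drives the engine roughly like this (names assumed):
//   PatternEngine engine(std::make_shared<DefaultVisitor>(), opt::AnfEqual);
//   auto equiv = engine.Match(pattern, expr, primitive_vars, std::make_shared<Equiv>());
//   // nullptr means no match; otherwise equiv maps each Var to its node.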
BaseRef PatternEngine::Replace(const BaseRef &pattern, const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(equiv);
MS_LOG(DEBUG) << "-----[in Replace]";
BaseRef ref = GetVar(pattern);
BaseRef out;
bool is_match = false;
// pattern root is a var: look up its binding in equiv
if (utils::isa<VarPtr>(ref)) {
const VarPtr &var = utils::cast<VarPtr>(ref);
auto iter = equiv->find(var);
if (iter != equiv->end()) {
out = iter->second;
is_match = true;
}
}
if (is_match) {
return out;
}
// visitor to visit the list
std::function<BaseRef(BaseRef)> fn = [&, this, equiv](const BaseRef &u) { return Replace(u, equiv); };
visitor_->SetFn(fn);
BaseRef visit_out;
if (!visitor_->Visit(pattern, &visit_out)) {
return pattern;
}
return visit_out;
}
} // namespace mindspore

View File

@@ -1,203 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_PATTERN_ENGINE_H_
#define MINDSPORE_LITE_SRC_PASS_COMMON_PATTERN_ENGINE_H_
#include <string>
#include <sstream>
#include <memory>
#include <vector>
#include <unordered_set>
#include <unordered_map>
#include <initializer_list>
#include <iostream>
#include <algorithm>
#include <map>
#include <stdexcept>
#include <list>
#include <utility>
#include "src/gllo/common/visit.h"
#include "mindspore/core/base/base.h"
#include "utils/log_adapter.h"
#include "base/base_ref.h"
namespace mindspore {
class CondVar;
class SeqVar;
using CondVarPtr = std::shared_ptr<CondVar>;
using SVarPtr = std::shared_ptr<SeqVar>;
const int kInvalidVarIndex = -2;
using ConditionFunc = std::function<bool(const BaseRef &)>;
// Base wildcard variable which could match any anf node.
class Var : public Base {
friend class VarHasher;
public:
explicit Var(std::string tag = "") : tag_(std::move(tag)), primitive_(nullptr) { EnsureTag(); }
explicit Var(const PrimitivePtr &primitive, std::string tag = "") : tag_(std::move(tag)), primitive_(primitive) {
EnsureTag();
}
Var(const Var &other) : Base(other), tag_(other.tag_) {}
virtual Var &operator=(const Var &other) {
if (&other == this) {
return *this;
}
this->tag_ = other.tag_;
return *this;
}
~Var() override = default;
MS_DECLARE_PARENT(Var, Base);
virtual bool matches(const BaseRef &) { return true; }
virtual bool operator==(const Var &other) const { return tag_ == other.tag_; }
bool operator!=(const Var &other) const { return !(&other == this); }
std::string tag() const { return tag_; }
PrimitivePtr primitive() const { return primitive_; }
std::string ToString() const override {
std::ostringstream buffer;
buffer << "Var(" << tag_ << ")";
return buffer.str();
}
std::size_t hash() const override { return std::hash<std::string>()(tag_); }
protected:
void EnsureTag();
std::string tag_;
PrimitivePtr primitive_;
};
// VarNode means variable node, a subclass of AnfNode
class VarNode : public AnfNode {
public:
VarNode(const VarPtr &value, const FuncGraphPtr &func_graph) : AnfNode(func_graph), var_(value) {}
~VarNode() override = default;
MS_DECLARE_PARENT(VarNode, AnfNode);
const VarPtr var_;
};
using VarNodePtr = std::shared_ptr<VarNode>;
class VarHasher {
public:
std::size_t operator()(const Var &var) const { return var.hash(); }
};
// Condition Var: matches an anf node when the condition function returns true.
class CondVar : public Var {
public:
explicit CondVar(const ConditionFunc &cond) : cond_fn_(cond) {}
~CondVar() override = default;
MS_DECLARE_PARENT(CondVar, Var);
bool matches(const BaseRef &value) override {
// MS_LOG(DEBUG) << "CondVarPtr match: " + value.ToString();
if (utils::isa<Var>(value)) {
return false;
}
return cond_fn_(value);
}
ConditionFunc cond_fn_;
};
using Seq = VectorRef;
using SeqPtr = std::shared_ptr<Seq>;
// Sequence Var which could match multiple consecutive input nodes of a CNode.
class SeqVar : public Var {
public:
SeqVar() : subvar_(std::make_shared<Var>()) {}
~SeqVar() override = default;
MS_DECLARE_PARENT(SeqVar, Var);
explicit SeqVar(const VarPtr subvar) : subvar_(subvar) {}
bool matches(const BaseRef &value) override {
// match Seq.
if (utils::isa<Seq>(value)) {
const Seq &seq = utils::cast<Seq>(value);
return std::all_of(seq.begin(), seq.end(), [this](const BaseRef &v) {
auto eq = subvar_->matches(v);
return eq;
});
}
return false;
}
bool operator==(const SeqVar &other) const { return *subvar_ == *other.subvar_; }
std::string ToString() const override;
private:
VarPtr subvar_;
};
bool operator==(const VarPtr &lhs, const VarPtr &rhs);
inline bool operator!=(const VarPtr &lhs, const VarPtr &rhs) { return !(lhs == rhs); }
std::ostream &operator<<(std::ostream &os, const VarPtr &var);
using Equiv = std::map<VarPtr, BaseRef>;
using EquivPtr = std::shared_ptr<Equiv>;
using PrimitiveVarMap = std::unordered_map<PrimitivePtr, VarPtr>;
using PrimitiveVarMapPtr = std::shared_ptr<PrimitiveVarMap>;
inline bool DefaultTypeEq(const BaseRef &x, const BaseRef &y) { return x.type() == y.type(); }
class PatternEngine {
public:
PatternEngine(const std::shared_ptr<Visitor> &visitor,
const std::function<bool(const BaseRef &, const BaseRef &)> &eq,
const std::function<bool(const BaseRef &, const BaseRef &)> &type_eq = DefaultTypeEq)
: visitor_(visitor), eq_(eq), type_eq_(type_eq) {}
~PatternEngine() = default;
EquivPtr Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars,
EquivPtr equiv) const;
// Replace pattern with equivalent
BaseRef Replace(const BaseRef &pattern, const EquivPtr &equiv) const;
private:
EquivPtr AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr,
const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const;
bool ToVector(const BaseRef &pattern, const BaseRef &expr, VectorRef *const values_pattern,
VectorRef *const values_expr) const;
bool ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern,
VectorRef *const values_expr) const;
std::shared_ptr<Visitor> visitor_;
std::function<bool(const BaseRef &, const BaseRef &)> eq_;
std::function<bool(const BaseRef &, const BaseRef &)> type_eq_;
};
} // namespace mindspore
namespace std {
using mindspore::ERROR;
using mindspore::LogStream;
using mindspore::NoExceptionType;
template <>
struct hash<mindspore::VarPtr> {
std::size_t operator()(const mindspore::VarPtr var) const {
if (var == nullptr) {
MS_LOG(ERROR) << "Invalid var ptr";
return 0;
}
return std::hash<std::string>{}(var->tag());
}
};
} // namespace std
#endif // MINDSPORE_LITE_SRC_PASS_COMMON_PATTERN_ENGINE_H_
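// Illustrative sketch of the three variable kinds declared above (hypothetical
// helper, not from this commit; assumes the mindspore namespace is open):
inline VectorRef ExamplePattern() {
  auto any_input = std::make_shared<Var>();        // wildcard: matches any node
  auto is_cnode = std::make_shared<CondVar>([](const BaseRef &n) {
    return utils::isa<AnfNodePtr>(n) && utils::cast<AnfNodePtr>(n)->isa<CNode>();
  });                                              // predicate-gated match
  auto rest = std::make_shared<SeqVar>();          // matches consecutive inputs
  return VectorRef({is_cnode, any_input, rest});   // head, one input, then the rest
}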

View File

@@ -1,165 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <memory>
#include <algorithm>
#include <utility>
#include "src/gllo/common/visit.h"
#include "src/gllo/common/pattern_engine.h"
#include "utils/any.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "utils/log_adapter.h"
namespace mindspore {
bool CheckIfNeedExpand(const std::vector<BaseRef> &list) {
return std::any_of(list.begin(), list.end(), [](const BaseRef &any) { return utils::isa<Seq>(any); });
}
std::shared_ptr<VectorRef> ExpandList(const std::vector<BaseRef> &list) {
std::shared_ptr<VectorRef> new_list = std::make_shared<VectorRef>();
for (auto &item : list) {
if (utils::isa<Seq>(item)) {
const Seq &seq = utils::cast<Seq>(item);
new_list->insert(new_list->end(), seq.begin(), seq.end());
} else {
new_list->push_back(item);
}
}
return new_list;
}
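// Illustrative walkthrough (hypothetical values, not from this commit): a Seq
// nested inside a list is spliced flat rather than kept nested.
//   input : { 1, Seq{2, 3}, 4 }  ->  CheckIfNeedExpand(...) returns true
//   output: { 1, 2, 3, 4 }
inline void ExpandListExample() {
  std::vector<BaseRef> items = {BaseRef(1), Seq({BaseRef(2), BaseRef(3)}), BaseRef(4)};
  std::shared_ptr<VectorRef> flat = ExpandList(items);  // flat->size() == 4
  (void)flat;
}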
bool DefaultVisitor::Visit(const VectorRef &v_any, BaseRef *const visit_out) const {
std::vector<BaseRef> out;
(void)std::transform(v_any.begin(), v_any.end(), std::back_inserter(out),
[this](const BaseRef &item) { return fn_(item); });
if (visit_out != nullptr) {
*visit_out = ExpandList(out);
}
return true;
}
bool DefaultVisitor::Visit(const BaseRef &any, BaseRef *const visit_out) const {
if (utils::isa<Seq>(any)) {
return Visit(utils::cast<Seq>(any), visit_out);
} else if (utils::isa<AnfNodePtr>(any)) {
auto nodeptr = utils::cast<AnfNodePtr>(any);
AnfNodePtr output;
AnfNodePtr *p_output = &output;
if (visit_out == nullptr) {
p_output = nullptr;
}
Visit(nodeptr, fn_, p_output);
if (visit_out != nullptr) {
*visit_out = output;
}
return true;
}
MS_LOG(DEBUG) << "VisitError, not support type to Visit: " + any.ToString();
return false;
}
void DefaultVisitor::Visit(const AnfNodePtr &node, const VisitFn &fn, AnfNodePtr *output) const {
if (node->isa<CNode>()) {
Visit(node->cast<CNodePtr>(), fn, output);
return;
}
if (node->isa<ValueNode>()) {
Visit(node->cast<ValueNodePtr>(), fn, output);
return;
}
if (output != nullptr) {
*output = node;
}
}
void DefaultVisitor::Visit(const CNodePtr &cnode, const VisitFn &fn, AnfNodePtr *output) const {
// if output is nullptr, there is no need to build a new CNode.
if (output == nullptr) {
for (auto &inp : cnode->inputs()) {
(void)fn(inp);
}
if (cnode->func_graph() != nullptr) {
(void)fn(cnode->func_graph());
} else {
(void)fn(cnode->func_graph_as_var());
}
return;
}
std::vector<AnfNodePtr> new_inputs;
std::vector<BaseRef> after_cnode_fn;
std::shared_ptr<VectorRef> out;
(void)std::transform(cnode->inputs().begin(), cnode->inputs().end(), std::back_inserter(after_cnode_fn), fn);
if (CheckIfNeedExpand(after_cnode_fn)) {
out = ExpandList(after_cnode_fn);
}
std::vector<BaseRef> &outs = after_cnode_fn;
if (out != nullptr) {
outs = out->elements();
}
for (auto &any_item : outs) {
if (!utils::isa<AnfNodePtr>(any_item)) {
MS_LOG(EXCEPTION) << "VisitError, fn not return the same type AnfNodePtr";
}
new_inputs.push_back(utils::cast<AnfNodePtr>(any_item));
}
BaseRef any_fg;
AnfNodePtr new_cnode = nullptr;
if (cnode->func_graph() != nullptr) {
any_fg = fn(cnode->func_graph());
if (!utils::isa<FuncGraphPtr>(any_fg)) {
MS_LOG(EXCEPTION) << "VisitError, fn not return the same type FuncGraphPtr";
}
new_cnode = std::make_shared<CNode>(new_inputs, utils::cast<FuncGraphPtr>(any_fg));
} else {
any_fg = fn(cnode->func_graph_as_var());
if (utils::isa<VarPtr>(any_fg)) {
new_cnode = std::make_shared<CNode>(new_inputs, utils::cast<VarPtr>(any_fg));
} else if (utils::isa<FuncGraphPtr>(any_fg)) {
new_cnode = std::make_shared<CNode>(new_inputs, utils::cast<FuncGraphPtr>(any_fg));
} else {
MS_LOG(EXCEPTION) << "VisitError, fn not return VarPtr or FuncGraphPtr";
}
}
new_cnode->set_abstract(cnode->abstract());
*output = new_cnode;
}
void DefaultVisitor::Visit(const ValueNodePtr &vnode, const VisitFn &fn, AnfNodePtr *output) const {
const BaseRef &value = utils::cast<ValuePtr>(fn(vnode->value()));
if (utils::isa<ValuePtr>(value)) {
if (output != nullptr) {
auto ct = NewValueNode(utils::cast<ValuePtr>(value));
ct->set_abstract(vnode->abstract());
*output = ct;
}
return;
}
MS_LOG(EXCEPTION) << "Visit result is not ValuePtr.";
}
} // namespace mindspore

View File

@@ -1,59 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LIFT_SRC_PASS_COMMON_VISIT_H_
#define MINDSPORE_LIFT_SRC_PASS_COMMON_VISIT_H_
#include <unordered_map>
#include <stdexcept>
#include <list>
#include <vector>
#include <string>
#include <memory>
#include "mindspore/core/base/base.h"
#include "base/base_ref.h"
namespace mindspore {
using VisitFn = std::function<BaseRef(const BaseRef &)>;
class Visitor {
public:
virtual void SetFn(VisitFn fn) = 0;
virtual bool Visit(const BaseRef &e, BaseRef *out) const = 0;
virtual bool Visit(const VectorRef &e, BaseRef *out) const = 0;
virtual ~Visitor() = default;
};
class DefaultVisitor : public Visitor {
public:
DefaultVisitor() : fn_(nullptr) {}
~DefaultVisitor() override = default;
void SetFn(VisitFn fn) override { fn_ = fn; };
bool Visit(const VectorRef &e, BaseRef *out) const override;
bool Visit(const BaseRef &e, BaseRef *out) const override;
void Visit(const AnfNodePtr &node, const VisitFn &fn, AnfNodePtr *output) const;
void Visit(const CNodePtr &cnode, const VisitFn &fn, AnfNodePtr *output) const;
void Visit(const ValueNodePtr &vnode, const VisitFn &fn, AnfNodePtr *output) const;
VisitFn fn_;
};
std::shared_ptr<VectorRef> ExpandList(const std::vector<BaseRef> &list);
bool CheckIfNeedExpand(const std::vector<BaseRef> &list);
} // namespace mindspore
#endif // MINDSPORE_LIFT_SRC_PASS_COMMON_VISIT_H_
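// Illustrative sketch of the visitor protocol above (hypothetical helper, not
// from this commit): install a rewrite function, then visit a node; the visitor
// rebuilds the node from the rewritten pieces and reports it through `out`.
inline bool VisitIdentity(const AnfNodePtr &node) {
  DefaultVisitor visitor;
  visitor.SetFn([](const BaseRef &item) { return item; });  // identity rewrite
  BaseRef out;
  return visitor.Visit(BaseRef(node), &out);  // out holds the rebuilt node
}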

View File

@@ -14,12 +14,12 @@
* limitations under the License.
*/
#include "mindspore/lite/src/gllo/fusion/conv_activation_fusion.h"
#include "src/gllo/fusion/conv_activation_fusion.h"
#include <memory>
#include "mindspore/lite/schema/inner/model_generated.h"
#include "mindspore/lite/src/ir/primitive_t_value.h"
#include "mindspore/ccsrc/utils/utils.h"
#include "mindspore/lite/src/gllo/common/utils.h"
#include "schema/inner/model_generated.h"
#include "src/ir/primitive_t_value.h"
#include "utils/utils.h"
#include "src/gllo/common/gllo_utils.h"
namespace mindspore::opt {
namespace {

View File

@@ -18,7 +18,7 @@
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_ACTIVATION_FUSION_H_
#include <string>
#include "mindspore/lite/src/gllo/common/optimizer.h"
#include "src/gllo/common/optimizer.h"
namespace mindspore {
namespace opt {

View File

@@ -13,13 +13,13 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindspore/lite/src/gllo/fusion/conv_biasadd_fusion.h"
#include <mindspore/lite/src/param_value_lite.h>
#include "src/gllo/fusion/conv_biasadd_fusion.h"
#include <memory>
#include "mindspore/lite/schema/inner/model_generated.h"
#include "mindspore/lite/src/ir/primitive_t_value.h"
#include "mindspore/ccsrc/utils/utils.h"
#include "mindspore/lite/src/gllo/common/utils.h"
#include "src/param_value_lite.h"
#include "schema/inner/model_generated.h"
#include "src/ir/primitive_t_value.h"
#include "utils/utils.h"
#include "src/gllo/common/gllo_utils.h"
#include "securec/include/securec.h"
namespace mindspore::opt {
@@ -142,7 +142,7 @@ const AnfNodePtr ConvBiasaddFusion::Process(const FuncGraphPtr &func_graph, cons
CheckIfCNodeIsNull(conv_node);
GenConvNewBias(func_graph, conv_node, add_node);
auto primitiveT_value = GetValueNode<std::shared_ptr<lite::PrimitiveTValue>>(conv_node->input(0));
MS_ASSERT(primitiveT_value);
MS_ASSERT(primitiveT_value != nullptr);
auto type = primitiveT_value->GetPrimitiveT()->value.type;
if (type == schema::PrimitiveType_Conv2D) {
primitiveT_value->GetPrimitiveT()->value.AsConv2D()->hasBias = true;

View File

@@ -17,7 +17,7 @@
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BIASADD_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BIASADD_FUSION_H_
#include "mindspore/lite/src/gllo/common/optimizer.h"
#include "src/gllo/common/optimizer.h"
namespace mindspore {
namespace opt {

View File

@@ -14,13 +14,13 @@
* limitations under the License.
*/
#include "mindspore/lite/src/gllo/fusion/conv_bn_fusion.h"
#include <mindspore/lite/src/param_value_lite.h>
#include "src/gllo/fusion/conv_bn_fusion.h"
#include <memory>
#include "mindspore/lite/schema/inner/model_generated.h"
#include "mindspore/lite/src/ir/primitive_t_value.h"
#include "mindspore/ccsrc/utils/utils.h"
#include "mindspore/lite/src/gllo/common/utils.h"
#include "src/param_value_lite.h"
#include "schema/inner/model_generated.h"
#include "src/ir/primitive_t_value.h"
#include "utils/utils.h"
#include "src/gllo/common/gllo_utils.h"
#include "securec/include/securec.h"
namespace mindspore::opt {

View File

@@ -17,7 +17,7 @@
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BN_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BN_FUSION_H_
#include "mindspore/lite/src/gllo/fusion/conv_transform_fusion.h"
#include "src/gllo/fusion/conv_transform_fusion.h"
namespace mindspore::opt {
class ConvBatchNormFusion : public ConvTransformFusion {

View File

@@ -14,13 +14,13 @@
* limitations under the License.
*/
#include "mindspore/lite/src/gllo/fusion/conv_scale_fusion.h"
#include "src/gllo/fusion/conv_scale_fusion.h"
#include <memory>
#include "mindspore/lite/src/param_value_lite.h"
#include "mindspore/lite/schema/inner/model_generated.h"
#include "mindspore/lite/src/ir/primitive_t_value.h"
#include "mindspore/ccsrc/utils/utils.h"
#include "mindspore/lite/src/gllo/common/utils.h"
#include "src/param_value_lite.h"
#include "schema/inner/model_generated.h"
#include "src/ir/primitive_t_value.h"
#include "utils/utils.h"
#include "src/gllo/common/gllo_utils.h"
#include "include/errorcode.h"
#include "securec/include/securec.h"

View File

@@ -17,7 +17,7 @@
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_
#include "mindspore/lite/src/gllo/fusion/conv_transform_fusion.h"
#include "src/gllo/fusion/conv_transform_fusion.h"
namespace mindspore::opt {
class ConvScaleFusion : public ConvTransformFusion {

View File

@@ -14,13 +14,13 @@
* limitations under the License.
*/
#include "mindspore/lite/src/gllo/fusion/conv_transform_fusion.h"
#include "src/gllo/fusion/conv_transform_fusion.h"
#include <memory>
#include "mindspore/lite/src/param_value_lite.h"
#include "mindspore/lite/schema/inner/model_generated.h"
#include "mindspore/lite/src/ir/primitive_t_value.h"
#include "mindspore/ccsrc/utils/utils.h"
#include "mindspore/lite/src/gllo/common/utils.h"
#include "src/param_value_lite.h"
#include "schema/inner/model_generated.h"
#include "src/ir/primitive_t_value.h"
#include "utils/utils.h"
#include "src/gllo/common/gllo_utils.h"
#include "include/errorcode.h"
#include "securec/include/securec.h"
@@ -78,6 +78,16 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co
GenNewConvTensor(func_graph, conv_node, kernel_nums, trans_scale, trans_bias);
delete[] trans_bias;
delete[] trans_scale;
auto primitiveT_value = GetValueNode<std::shared_ptr<lite::PrimitiveTValue>>(conv_node->input(0));
MS_ASSERT(primitiveT_value != nullptr);
auto type = primitiveT_value->GetPrimitiveT()->value.type;
if (type == schema::PrimitiveType_Conv2D) {
primitiveT_value->GetPrimitiveT()->value.AsConv2D()->hasBias = true;
} else if (type == schema::PrimitiveType_DepthwiseConv2D) {
primitiveT_value->GetPrimitiveT()->value.AsDepthwiseConv2D()->hasBias = true;
} else {
MS_LOG(EXCEPTION) << "Unsupported opType, " << type;
}
return pre_node;
}
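// Illustrative check of the hasBias flag set above (hypothetical snippet, not
// from this commit), reading it back from an exported schema::CNodeT the way
// the tests later in this PR do:
inline bool HasFusedBias(const std::unique_ptr<schema::CNodeT> &cnode) {
  auto &val = cnode->primitive->value;
  return (val.type == schema::PrimitiveType_Conv2D) ? val.AsConv2D()->hasBias
                                                    : val.AsDepthwiseConv2D()->hasBias;
}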

View File

@@ -18,7 +18,7 @@
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_TRANSFORM_FUSION_H_
#include <string>
#include "mindspore/lite/src/gllo/common/optimizer.h"
#include "src/gllo/common/optimizer.h"
namespace mindspore::opt {
class ConvTransformFusion : public PatternProcessPass {

View File

@@ -63,6 +63,8 @@ if(BUILD_CONVERTER)
${CCSRC_DIR}/pybind_api/export_flags.cc
${CCSRC_DIR}/utils/context/context_extends.cc
${CCSRC_DIR}/frontend/parallel/costmodel_context.cc
${CCSRC_DIR}/backend/optimizer/common/pattern_engine.cc
${CCSRC_DIR}/backend/optimizer/common/visit.cc
${CMAKE_CURRENT_SOURCE_DIR}/../src/common/graph_utils_extends.cc
)
else()
@@ -202,12 +204,14 @@ if(BUILD_CONVERTER)
${LITE_DIR}/tools/converter/converter.cc
${LITE_DIR}/tools/converter/parser/onnx/onnx.pb.cc
${LITE_DIR}/test/st/converter_test.cc
${LITE_DIR}/test/ut/src/gllo/fusion/conv_activation_fusion_test.cc
${LITE_DIR}/test/ut/src/gllo/fusion/conv_biasadd_fusion_test.cc
${LITE_DIR}/test/ut/src/gllo/fusion/conv_bn_fusion_test.cc
${LITE_DIR}/test/ut/src/gllo/fusion/conv_scale_fusion_test.cc
${LITE_DIR}/src/gllo/common/node_pass.cc
${LITE_DIR}/src/gllo/common/optimizer.cc
${LITE_DIR}/src/gllo/common/pass_manager.cc
${LITE_DIR}/src/gllo/common/pattern_engine.cc
${LITE_DIR}/src/gllo/common/visit.cc
${LITE_DIR}/src/gllo/common/utils.cc
${LITE_DIR}/src/gllo/common/gllo_utils.cc
${LITE_DIR}/src/gllo/fusion/conv_biasadd_fusion.cc
${LITE_DIR}/src/gllo/fusion/conv_activation_fusion.cc
${LITE_DIR}/src/gllo/fusion/conv_transform_fusion.cc

View File

@@ -0,0 +1,184 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "schema/inner/model_generated.h"
#include "include/model.h"
#include "common/common_test.h"
#include "include/lite_session.h"
#include "include/context.h"
#include "include/errorcode.h"
#include "utils/log_adapter.h"
#include "tools/converter/model_parser.h"
#include "tools/converter/anf_transform.h"
#include "src/common/anf_exporter/anf_exporter.h"
namespace mindspore {
class ConvActivationFusionTest : public mindspore::Common {
public:
ConvActivationFusionTest() = default;
};
using MetaGraphTptr = std::shared_ptr<schema::MetaGraphT>;
using CNodeTptr = std::unique_ptr<schema::CNodeT>;
namespace {
CNodeTptr BuildConv2D() {
auto convNode = std::make_unique<schema::CNodeT>();
convNode->inputIndex = {0, 1};
convNode->outputIndex = {2};
convNode->primitive = std::make_unique<schema::PrimitiveT>();
convNode->primitive->value.type = schema::PrimitiveType_Conv2D;
auto prim1 = new schema::Conv2DT;
prim1->padMode = schema::PadMode_SAME;
prim1->format = schema::Format_NHWC;
prim1->strideH = 1;
prim1->strideW = 1;
prim1->kernelH = 3;
prim1->kernelW = 3;
prim1->dilateH = 1;
prim1->dilateW = 1;
prim1->channelOut = 3;
convNode->primitive->value.value = prim1;
convNode->name = "Conv2D";
return convNode;
}
CNodeTptr BuildDepthwiseConv2D() {
auto convNode = std::make_unique<schema::CNodeT>();
convNode->inputIndex = {0, 1};
convNode->outputIndex = {2};
convNode->primitive = std::make_unique<schema::PrimitiveT>();
convNode->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
auto prim1 = new schema::DepthwiseConv2DT;
prim1->padMode = schema::PadMode_SAME;
prim1->format = schema::Format_NHWC;
prim1->strideH = 1;
prim1->strideW = 1;
prim1->kernelH = 3;
prim1->kernelW = 3;
prim1->dilateH = 1;
prim1->dilateW = 1;
prim1->channelIn = 1;
prim1->channelMultiplier = 3;
convNode->primitive->value.value = prim1;
convNode->name = "Conv2D";
return convNode;
}
MetaGraphTptr BuildGraph(schema::PrimitiveType conv_type,
schema::ActivationType activation_type) {
auto meta_graph = std::make_shared<schema::MetaGraphT>();
meta_graph->name = "graph";
// conv node
CNodeTptr convNode;
if (conv_type == schema::PrimitiveType_Conv2D) {
convNode = BuildConv2D();
} else {
convNode = BuildDepthwiseConv2D();
}
meta_graph->nodes.emplace_back(std::move(convNode));
// relu node
auto next_node = std::make_unique<schema::CNodeT>();
next_node->inputIndex = {2};
next_node->outputIndex = {3};
next_node->primitive = std::make_unique<schema::PrimitiveT>();
next_node->primitive->value.type = schema::PrimitiveType_Activation;
auto prim2 = new schema::ActivationT;
prim2->type = activation_type;
next_node->primitive->value.value = prim2;
next_node->name = "activation";
meta_graph->nodes.emplace_back(std::move(next_node));
meta_graph->inputIndex = {0};
meta_graph->outputIndex = {3};
// input 0: data
auto input0 = std::make_unique<schema::TensorT>();
input0->nodeType = schema::NodeType::NodeType_ValueNode;
input0->format = schema::Format_NHWC;
input0->dataType = TypeId::kNumberTypeFloat32;
input0->dims = {1, 5, 5, 3};
input0->offset = -1;
meta_graph->allTensors.emplace_back(std::move(input0));
// input 1: weight
auto input1 = std::make_unique<schema::TensorT>();
input1->nodeType = schema::NodeType::NodeType_ValueNode;
input1->format = schema::Format_KHWC;
input1->dataType = TypeId::kNumberTypeFloat32;
input1->dims = {8, 3, 3, 3};
input1->data.resize(sizeof(float) * 8 * 3 * 3 * 3);
meta_graph->allTensors.emplace_back(std::move(input1));
// conv output
auto conv_output = std::make_unique<schema::TensorT>();
conv_output->nodeType = schema::NodeType::NodeType_Parameter;
conv_output->format = schema::Format_NHWC;
conv_output->dataType = TypeId::kNumberTypeFloat32;
conv_output->dims = {1, 5, 5, 8};
meta_graph->allTensors.emplace_back(std::move(conv_output));
// final output
auto output = std::make_unique<schema::TensorT>();
output->nodeType = schema::NodeType::NodeType_Parameter;
output->format = schema::Format_NHWC;
output->dataType = TypeId::kNumberTypeFloat32;
output->dims = {1, 5, 5, 8};
meta_graph->allTensors.emplace_back(std::move(output));
return meta_graph;
}
} // namespace
TEST_F(ConvActivationFusionTest, TestConvReluNode) {
auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2D, schema::ActivationType_RELU);
auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
auto anf_transform = new lite::AnfTransform();
auto new_graph = anf_transform->Transform(func_graph);
ASSERT_NE(nullptr, new_graph);
auto new_meta_graph = lite::Export(new_graph);
ASSERT_EQ(new_meta_graph->nodes.size(), 1);
for (auto &cnode : new_meta_graph->nodes) {
ASSERT_EQ(cnode->primitive->value.AsConv2D()->activationType, schema::ActivationType_RELU);
}
}
TEST_F(ConvActivationFusionTest, TestConvRelu6Node) {
auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2D, schema::ActivationType_RELU6);
auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
auto anf_transform = new lite::AnfTransform();
auto new_graph = anf_transform->Transform(func_graph);
ASSERT_NE(nullptr, new_graph);
auto new_meta_graph = lite::Export(new_graph);
ASSERT_EQ(new_meta_graph->nodes.size(), 1);
for (auto &cnode : new_meta_graph->nodes) {
ASSERT_EQ(cnode->primitive->value.AsConv2D()->activationType, schema::ActivationType_RELU6);
}
}
TEST_F(ConvActivationFusionTest, TestBadCase_ConvRelu) {
auto meta_graph = BuildGraph(schema::PrimitiveType_DepthwiseConv2D, schema::ActivationType_LEAKY_RELU);
auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
auto anf_transform = new lite::AnfTransform();
auto new_graph = anf_transform->Transform(func_graph);
ASSERT_NE(nullptr, new_graph);
auto new_meta_graph = lite::Export(new_graph);
ASSERT_EQ(new_meta_graph->nodes.size(), 2);
for (auto &cnode : new_meta_graph->nodes) {
if (cnode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) {
ASSERT_EQ(cnode->primitive->value.AsDepthwiseConv2D()->activationType, schema::ActivationType_NO_ACTIVATION);
}
}
}
} // namespace mindspore

View File

@@ -0,0 +1,194 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "schema/inner/model_generated.h"
#include "include/model.h"
#include "common/common_test.h"
#include "include/lite_session.h"
#include "include/context.h"
#include "include/errorcode.h"
#include "utils/log_adapter.h"
#include "tools/converter/model_parser.h"
#include "tools/converter/anf_transform.h"
#include "src/common/anf_exporter/anf_exporter.h"
namespace mindspore {
class ConvBiasAddFusionTest : public mindspore::Common {
public:
ConvBiasAddFusionTest() = default;
};
using MetaGraphTptr = std::shared_ptr<schema::MetaGraphT>;
using CNodeTptr = std::unique_ptr<schema::CNodeT>;
namespace {
CNodeTptr BuildConv2D() {
auto convNode = std::make_unique<schema::CNodeT>();
convNode->inputIndex = {0, 1};
convNode->outputIndex = {2};
convNode->primitive = std::make_unique<schema::PrimitiveT>();
convNode->primitive->value.type = schema::PrimitiveType_Conv2D;
auto prim1 = new schema::Conv2DT;
prim1->padMode = schema::PadMode_SAME;
prim1->format = schema::Format_NHWC;
prim1->strideH = 1;
prim1->strideW = 1;
prim1->kernelH = 3;
prim1->kernelW = 3;
prim1->dilateH = 1;
prim1->dilateW = 1;
prim1->channelOut = 3;
convNode->primitive->value.value = prim1;
convNode->name = "Conv2D";
return convNode;
}
CNodeTptr BuildDepthwiseConv2D() {
auto convNode = std::make_unique<schema::CNodeT>();
convNode->inputIndex = {0, 1};
convNode->outputIndex = {2};
convNode->primitive = std::make_unique<schema::PrimitiveT>();
convNode->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
auto prim1 = new schema::DepthwiseConv2DT;
prim1->padMode = schema::PadMode_SAME;
prim1->format = schema::Format_NHWC;
prim1->strideH = 1;
prim1->strideW = 1;
prim1->kernelH = 3;
prim1->kernelW = 3;
prim1->dilateH = 1;
prim1->dilateW = 1;
prim1->channelIn = 1;
prim1->channelMultiplier = 3;
convNode->primitive->value.value = prim1;
convNode->name = "Conv2D";
return convNode;
}
MetaGraphTptr BuildGraph(schema::PrimitiveType conv_type,
schema::PrimitiveType add_type) {
auto meta_graph = std::make_shared<schema::MetaGraphT>();
meta_graph->name = "graph";
// conv node
CNodeTptr convNode;
if (conv_type == schema::PrimitiveType_Conv2D) {
convNode = BuildConv2D();
} else {
convNode = BuildDepthwiseConv2D();
}
meta_graph->nodes.emplace_back(std::move(convNode));
// biasadd node
auto biasadd_node = std::make_unique<schema::CNodeT>();
biasadd_node->inputIndex = {2, 3};
biasadd_node->outputIndex = {4};
biasadd_node->primitive = std::make_unique<schema::PrimitiveT>();
biasadd_node->primitive->value.type = add_type;
auto prim2 = new schema::BiasAddT;
biasadd_node->primitive->value.value = prim2;
biasadd_node->name = "BiasAdd";
meta_graph->nodes.emplace_back(std::move(biasadd_node));
meta_graph->inputIndex = {0};
meta_graph->outputIndex = {4};
// input 0: data
auto input0 = std::make_unique<schema::TensorT>();
input0->nodeType = schema::NodeType::NodeType_ValueNode;
input0->format = schema::Format_NHWC;
input0->dataType = TypeId::kNumberTypeFloat32;
input0->dims = {1, 5, 5, 3};
input0->offset = -1;
meta_graph->allTensors.emplace_back(std::move(input0));
// input 1: weight
auto input1 = std::make_unique<schema::TensorT>();
input1->nodeType = schema::NodeType::NodeType_ValueNode;
input1->format = schema::Format_KHWC;
input1->dataType = TypeId::kNumberTypeFloat32;
input1->dims = {8, 3, 3, 3};
input1->data.resize(sizeof(float) * 8 * 3 * 3 * 3);
meta_graph->allTensors.emplace_back(std::move(input1));
// conv output
auto conv_output = std::make_unique<schema::TensorT>();
conv_output->nodeType = schema::NodeType::NodeType_Parameter;
conv_output->format = schema::Format_NHWC;
conv_output->dataType = TypeId::kNumberTypeFloat32;
conv_output->dims = {1, 5, 5, 8};
meta_graph->allTensors.emplace_back(std::move(conv_output));
// input2: bias
auto input2 = std::make_unique<schema::TensorT>();
input2->nodeType = schema::NodeType::NodeType_ValueNode;
input2->format = schema::Format_NHWC;
input2->dataType = TypeId::kNumberTypeFloat32;
input2->dims = {1, 5, 5, 8};
input2->data.resize(sizeof(float) * 8 * 5 * 5);
meta_graph->allTensors.emplace_back(std::move(input2));
// final output
auto output = std::make_unique<schema::TensorT>();
output->nodeType = schema::NodeType::NodeType_Parameter;
output->format = schema::Format_NHWC;
output->dataType = TypeId::kNumberTypeFloat32;
output->dims = {1, 5, 5, 8};
meta_graph->allTensors.emplace_back(std::move(output));
return meta_graph;
}
} // namespace
TEST_F(ConvBiasAddFusionTest, TestConvAddNode) {
auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2D, schema::PrimitiveType_BiasAdd);
auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
lite::AnfTransform anf_transform;
auto new_graph = anf_transform.Transform(func_graph);
ASSERT_NE(nullptr, new_graph);
auto new_meta_graph = lite::Export(new_graph);
ASSERT_EQ(new_meta_graph->nodes.size(), 1);
for (auto &cnode : new_meta_graph->nodes) {
ASSERT_EQ(cnode->primitive->value.AsConv2D()->hasBias, true);
}
MS_LOG(INFO) << "Passed";
}
TEST_F(ConvBiasAddFusionTest, TestDepthwiseConvAddNode) {
auto meta_graph = BuildGraph(schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_Add);
auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
lite::AnfTransform anf_transform;
auto new_graph = anf_transform.Transform(func_graph);
ASSERT_NE(nullptr, new_graph);
auto new_meta_graph = lite::Export(new_graph);
ASSERT_EQ(new_meta_graph->nodes.size(), 1);
for (auto &cnode : new_meta_graph->nodes) {
ASSERT_EQ(cnode->primitive->value.AsDepthwiseConv2D()->hasBias, true);
}
}
TEST_F(ConvBiasAddFusionTest, TestBadCase_ConvAdd) {
auto meta_graph = BuildGraph(schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_MatMul);
auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
lite::AnfTransform anf_transform;
auto new_graph = anf_transform.Transform(func_graph);
ASSERT_NE(nullptr, new_graph);
auto new_meta_graph = lite::Export(new_graph);
ASSERT_EQ(new_meta_graph->nodes.size(), 2);
for (auto &cnode : new_meta_graph->nodes) {
if (cnode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) {
ASSERT_EQ(cnode->primitive->value.AsDepthwiseConv2D()->hasBias, false);
}
}
}
} // namespace mindspore
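For reference, the arithmetic the ConvBiasAdd fusion is expected to perform is plain per-channel bias folding. A minimal standalone sketch of that folding (illustrative only, not the converter's actual pass; the function name is made up):
#include <vector>
// conv(x, W, b) followed by a per-channel (+ a) is equivalent to
// conv(x, W, b + a); when the conv has no bias, the add's bias becomes
// the conv bias outright, which is why hasBias flips to true above.
std::vector<float> FoldBiasAdd(const std::vector<float> &conv_bias,
                               const std::vector<float> &add_bias) {
  std::vector<float> fused(add_bias);  // start from the added bias
  for (size_t c = 0; c < conv_bias.size() && c < fused.size(); ++c) {
    fused[c] += conv_bias[c];  // merge an existing conv bias, if any
  }
  return fused;
}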

View File

@ -0,0 +1,296 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "schema/inner/model_generated.h"
#include "include/model.h"
#include "common/common_test.h"
#include "include/lite_session.h"
#include "include/context.h"
#include "include/errorcode.h"
#include "mindspore/core/utils/log_adapter.h"
#include "tools/converter/model_parser.h"
#include "tools/converter/anf_transform.h"
#include "src/common/anf_exporter/anf_exporter.h"
namespace mindspore {
class ConvBNFusionTest : public mindspore::Common {
public:
ConvBNFusionTest() = default;
};
using MetaGraphTptr = std::shared_ptr<schema::MetaGraphT>;
using CNodeTptr = std::unique_ptr<schema::CNodeT>;
namespace {
CNodeTptr BuildConv2D() {
auto convNode = std::make_unique<schema::CNodeT>();
convNode->inputIndex = {0, 1};
convNode->outputIndex = {2};
convNode->primitive = std::make_unique<schema::PrimitiveT>();
convNode->primitive->value.type = schema::PrimitiveType_Conv2D;
auto prim1 = new schema::Conv2DT;
prim1->padMode = schema::PadMode_SAME;
prim1->format = schema::Format_NHWC;
prim1->strideH = 1;
prim1->strideW = 1;
prim1->kernelH = 3;
prim1->kernelW = 3;
prim1->dilateH = 1;
prim1->dilateW = 1;
prim1->channelOut = 3;
convNode->primitive->value.value = prim1;
convNode->name = "Conv2D";
return convNode;
}
CNodeTptr BuildDepthwiseConv2D() {
auto convNode = std::make_unique<schema::CNodeT>();
convNode->inputIndex = {0, 1, 2};
convNode->outputIndex = {3};
convNode->primitive = std::make_unique<schema::PrimitiveT>();
convNode->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
auto prim1 = new schema::DepthwiseConv2DT;
prim1->padMode = schema::PadMode_SAME;
prim1->format = schema::Format_NHWC;
prim1->strideH = 1;
prim1->strideW = 1;
prim1->kernelH = 3;
prim1->kernelW = 3;
prim1->dilateH = 1;
prim1->dilateW = 1;
prim1->channelIn = 1;
prim1->channelMultiplier = 3;
convNode->primitive->value.value = prim1;
convNode->name = "DepthwiseConv2D";
return convNode;
}
// caffe bn node has 3 inputs: the conv output plus mean and variance
MetaGraphTptr BuildCaffeGraph(schema::PrimitiveType conv_type) {
auto meta_graph = std::make_shared<schema::MetaGraphT>();
meta_graph->name = "graph";
// conv node
CNodeTptr convNode;
if (conv_type == schema::PrimitiveType_Conv2D) {
convNode = BuildConv2D();
} else {
convNode = BuildDepthwiseConv2D();
}
meta_graph->nodes.emplace_back(std::move(convNode));
// bn_node
auto bn_node = std::make_unique<schema::CNodeT>();
bn_node->inputIndex = {2, 3, 4};
bn_node->outputIndex = {5};
bn_node->primitive = std::make_unique<schema::PrimitiveT>();
bn_node->primitive->value.type = schema::PrimitiveType_CaffeBatchNorm;
auto prim2 = new schema::CaffeBatchNormT;
bn_node->primitive->value.value = prim2;
bn_node->name = "bn";
meta_graph->nodes.emplace_back(std::move(bn_node));
// input 0: data
auto input0 = std::make_unique<schema::TensorT>();
input0->nodeType = schema::NodeType::NodeType_ValueNode;
input0->format = schema::Format_NHWC;
input0->dataType = TypeId::kNumberTypeFloat32;
input0->dims = {1, 5, 5, 3};
input0->offset = -1;
meta_graph->allTensors.emplace_back(std::move(input0));
// input 1: weight
auto input1 = std::make_unique<schema::TensorT>();
input1->nodeType = schema::NodeType::NodeType_ValueNode;
input1->format = schema::Format_KHWC;
input1->dataType = TypeId::kNumberTypeFloat32;
input1->dims = {8, 3, 3, 3};
input1->data.resize(sizeof(float) * 8 * 3 * 3 * 3);
meta_graph->allTensors.emplace_back(std::move(input1));
// conv output
auto conv_output = std::make_unique<schema::TensorT>();
conv_output->nodeType = schema::NodeType::NodeType_Parameter;
conv_output->format = schema::Format_NHWC;
conv_output->dataType = TypeId::kNumberTypeFloat32;
conv_output->dims = {1, 5, 5, 8};
meta_graph->allTensors.emplace_back(std::move(conv_output));
// caffe bn : mean
auto input2 = std::make_unique<schema::TensorT>();
input2->nodeType = schema::NodeType::NodeType_ValueNode;
input2->format = schema::Format_NHWC;
input2->dataType = TypeId::kNumberTypeFloat32;
input2->dims = {1, 5, 5, 8};
input2->data.resize(sizeof(float) * 8 * 5 * 5);
meta_graph->allTensors.emplace_back(std::move(input2));
// caffe bn : var
auto input3 = std::make_unique<schema::TensorT>();
input3->nodeType = schema::NodeType::NodeType_ValueNode;
input3->format = schema::Format_NHWC;
input3->dataType = TypeId::kNumberTypeFloat32;
input3->dims = {1, 5, 5, 8};
input3->data.resize(sizeof(float) * 8 * 5 * 5);
meta_graph->allTensors.emplace_back(std::move(input3));
// final bn output
auto output = std::make_unique<schema::TensorT>();
output->nodeType = schema::NodeType::NodeType_Parameter;
output->format = schema::Format_NHWC;
output->dataType = TypeId::kNumberTypeFloat32;
output->dims = {1, 5, 5, 8};
meta_graph->allTensors.emplace_back(std::move(output));
meta_graph->inputIndex = {0};
meta_graph->outputIndex = {5};
return meta_graph;
}
// tf fused bn node takes the conv output plus 4 parameter inputs: scale, bias, mean and variance
MetaGraphTptr BuildTFGraph(schema::PrimitiveType conv_type) {
auto meta_graph = std::make_shared<schema::MetaGraphT>();
meta_graph->name = "graph";
// conv node
CNodeTptr convNode;
if (conv_type == schema::PrimitiveType_Conv2D) {
convNode = BuildConv2D();
} else {
convNode = BuildDepthwiseConv2D();
}
meta_graph->nodes.emplace_back(std::move(convNode));
// bn_node
auto bn_node = std::make_unique<schema::CNodeT>();
bn_node->inputIndex = {3, 4, 5, 6, 7};
bn_node->outputIndex = {8};
bn_node->primitive = std::make_unique<schema::PrimitiveT>();
bn_node->primitive->value.type = schema::PrimitiveType_FusedBatchNorm;
auto prim2 = new schema::FusedBatchNormT;
bn_node->primitive->value.value = prim2;
bn_node->name = "bn";
meta_graph->nodes.emplace_back(std::move(bn_node));
// input 0: data
auto input0 = std::make_unique<schema::TensorT>();
input0->nodeType = schema::NodeType::NodeType_ValueNode;
input0->format = schema::Format_NHWC;
input0->dataType = TypeId::kNumberTypeFloat32;
input0->dims = {1, 5, 5, 3};
input0->offset = -1;
meta_graph->allTensors.emplace_back(std::move(input0));
// input 1: conv_bias
auto input11 = std::make_unique<schema::TensorT>();
input11->nodeType = schema::NodeType::NodeType_ValueNode;
input11->format = schema::Format_KHWC;
input11->dataType = TypeId::kNumberTypeFloat32;
input11->dims = {8, 3, 3, 3};
input11->data.resize(sizeof(float) * 8 * 3 * 3 * 3);
meta_graph->allTensors.emplace_back(std::move(input11));
// input 2: weight
auto input1 = std::make_unique<schema::TensorT>();
input1->nodeType = schema::NodeType::NodeType_ValueNode;
input1->format = schema::Format_KHWC;
input1->dataType = TypeId::kNumberTypeFloat32;
input1->dims = {8, 3, 3, 3};
input1->data.resize(sizeof(float) * 8 * 3 * 3 * 3);
meta_graph->allTensors.emplace_back(std::move(input1));
// conv output
auto conv_output = std::make_unique<schema::TensorT>();
conv_output->nodeType = schema::NodeType::NodeType_Parameter;
conv_output->format = schema::Format_NHWC;
conv_output->dataType = TypeId::kNumberTypeFloat32;
conv_output->dims = {1, 5, 5, 8};
meta_graph->allTensors.emplace_back(std::move(conv_output));
// tf bn : scale
auto input2 = std::make_unique<schema::TensorT>();
input2->nodeType = schema::NodeType::NodeType_ValueNode;
input2->format = schema::Format_NHWC;
input2->dataType = TypeId::kNumberTypeFloat32;
input2->dims = {1, 5, 5, 8};
input2->data.resize(sizeof(float) * 8 * 5 * 5);
meta_graph->allTensors.emplace_back(std::move(input2));
// tf bn : bias
auto input3 = std::make_unique<schema::TensorT>();
input3->nodeType = schema::NodeType::NodeType_ValueNode;
input3->format = schema::Format_NHWC;
input3->dataType = TypeId::kNumberTypeFloat32;
input3->dims = {1, 5, 5, 8};
input3->data.resize(sizeof(float) * 8 * 5 * 5);
meta_graph->allTensors.emplace_back(std::move(input3));
// tf bn : mean
auto input4 = std::make_unique<schema::TensorT>();
input4->nodeType = schema::NodeType::NodeType_ValueNode;
input4->format = schema::Format_NHWC;
input4->dataType = TypeId::kNumberTypeFloat32;
input4->dims = {1, 5, 5, 8};
input4->data.resize(sizeof(float) * 8 * 5 * 5);
meta_graph->allTensors.emplace_back(std::move(input4));
// tf bn : var
auto input5 = std::make_unique<schema::TensorT>();
input5->nodeType = schema::NodeType::NodeType_ValueNode;
input5->format = schema::Format_NHWC;
input5->dataType = TypeId::kNumberTypeFloat32;
input5->dims = {1, 5, 5, 8};
input5->data.resize(sizeof(float) * 8 * 5 * 5);
meta_graph->allTensors.emplace_back(std::move(input5));
// final output
auto output = std::make_unique<schema::TensorT>();
output->nodeType = schema::NodeType::NodeType_Parameter;
output->format = schema::Format_NHWC;
output->dataType = TypeId::kNumberTypeFloat32;
output->dims = {1, 5, 5, 8};
meta_graph->allTensors.emplace_back(std::move(output));
meta_graph->inputIndex = {0};
meta_graph->outputIndex = {8};
return meta_graph;
}
} // namespace
TEST_F(ConvBNFusionTest, TestConvBnNode) {
auto meta_graph = BuildCaffeGraph(schema::PrimitiveType_Conv2D);
auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
lite::AnfTransform anf_transform;
auto new_graph = anf_transform.Transform(func_graph);
ASSERT_NE(nullptr, new_graph);
auto new_meta_graph = lite::Export(new_graph);
ASSERT_EQ(new_meta_graph->nodes.size(), 1);
for (auto &cnode : new_meta_graph->nodes) {
ASSERT_EQ(cnode->primitive->value.AsConv2D()->hasBias, true);
}
}
TEST_F(ConvBNFusionTest, TestDepthwiseConvBnNode) {
auto meta_graph = BuildTFGraph(schema::PrimitiveType_DepthwiseConv2D);
auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
lite::AnfTransform anf_transform;
auto new_graph = anf_transform.Transform(func_graph);
ASSERT_NE(nullptr, new_graph);
auto new_meta_graph = lite::Export(new_graph);
ASSERT_EQ(new_meta_graph->nodes.size(), 1);
for (auto &cnode : new_meta_graph->nodes) {
ASSERT_EQ(cnode->primitive->value.AsDepthwiseConv2D()->hasBias, true);
}
}
} // namespace mindspore
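The BN folding behind these assertions: with s_c = gamma_c / sqrt(var_c + eps), the normalization y = gamma * (Wx + b - mean) / sqrt(var + eps) + beta collapses into W'_c = s_c * W_c and b'_c = s_c * (b_c - mean_c) + beta_c (for caffe bn, gamma = 1 and beta = 0). A hedged standalone sketch under those assumptions, not the pass's real code:
#include <cmath>
#include <vector>
// Fold per-channel batch norm into conv weight/bias. Weights are assumed
// channel-major (KHWC, matching the test tensors with dims {8, 3, 3, 3}).
void FoldBatchNorm(std::vector<float> *weight, std::vector<float> *bias,
                   const std::vector<float> &gamma, const std::vector<float> &beta,
                   const std::vector<float> &mean, const std::vector<float> &var,
                   float eps, int out_channels) {
  const int per_channel = static_cast<int>(weight->size()) / out_channels;
  for (int c = 0; c < out_channels; ++c) {
    const float s = gamma[c] / std::sqrt(var[c] + eps);
    for (int i = 0; i < per_channel; ++i) {
      (*weight)[c * per_channel + i] *= s;  // scale this output channel's taps
    }
    (*bias)[c] = s * ((*bias)[c] - mean[c]) + beta[c];  // fold the shift into the bias
  }
}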

View File

@ -0,0 +1,221 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "schema/inner/model_generated.h"
#include "include/model.h"
#include "common/common_test.h"
#include "include/lite_session.h"
#include "include/context.h"
#include "include/errorcode.h"
#include "utils/log_adapter.h"
#include "tools/converter/model_parser.h"
#include "tools/converter/anf_transform.h"
#include "src/common/anf_exporter/anf_exporter.h"
namespace mindspore {
class ConvScaleFusionTest : public mindspore::Common {
public:
ConvScaleFusionTest() = default;
};
using MetaGraphTptr = std::shared_ptr<schema::MetaGraphT>;
using CNodeTptr = std::unique_ptr<schema::CNodeT>;
namespace {
// conv node has 2 inputs (data, weight), or 3 when a bias is fused in
CNodeTptr BuildConv2D(bool with_bias_flag) {
auto convNode = std::make_unique<schema::CNodeT>();
if (with_bias_flag) {
convNode->inputIndex = {0, 1, 2};
convNode->outputIndex = {3};
} else {
convNode->inputIndex = {0, 1};
convNode->outputIndex = {2};
}
convNode->primitive = std::make_unique<schema::PrimitiveT>();
convNode->primitive->value.type = schema::PrimitiveType_Conv2D;
auto prim1 = new schema::Conv2DT;
prim1->padMode = schema::PadMode_SAME;
prim1->format = schema::Format_NHWC;
prim1->strideH = 1;
prim1->strideW = 1;
prim1->kernelH = 3;
prim1->kernelW = 3;
prim1->dilateH = 1;
prim1->dilateW = 1;
prim1->channelOut = 3;
convNode->primitive->value.value = prim1;
convNode->name = "Conv2D";
return convNode;
}
// depthwise conv node has 2 inputs (data, weight), or 3 when a bias is fused in
CNodeTptr BuildDepthwiseConv2D(bool with_bias_flag) {
auto convNode = std::make_unique<schema::CNodeT>();
if (with_bias_flag) {
convNode->inputIndex = {0, 1, 2};
convNode->outputIndex = {3};
} else {
convNode->inputIndex = {0, 1};
convNode->outputIndex = {2};
}
convNode->primitive = std::make_unique<schema::PrimitiveT>();
convNode->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
auto prim1 = new schema::DepthwiseConv2DT;
prim1->padMode = schema::PadMode_SAME;
prim1->format = schema::Format_NHWC;
prim1->strideH = 1;
prim1->strideW = 1;
prim1->kernelH = 3;
prim1->kernelW = 3;
prim1->dilateH = 1;
prim1->dilateW = 1;
prim1->channelIn = 1;
prim1->channelMultiplier = 3;
convNode->primitive->value.value = prim1;
convNode->name = "DepthwiseConv2D";
return convNode;
}
MetaGraphTptr BuildGraph(schema::PrimitiveType conv_type, bool conv_with_bias) {
auto meta_graph = std::make_shared<schema::MetaGraphT>();
meta_graph->name = "graph";
// conv node
CNodeTptr convNode;
if (conv_type == schema::PrimitiveType_Conv2D) {
convNode = BuildConv2D(conv_with_bias);
} else {
convNode = BuildDepthwiseConv2D(conv_with_bias);
}
meta_graph->nodes.emplace_back(std::move(convNode));
// scale node: inputs are the conv output, the scale weight and the scale bias
auto scale_node = std::make_unique<schema::CNodeT>();
if (conv_with_bias) {
scale_node->inputIndex = {3, 4, 5};
scale_node->outputIndex = {6};
} else {
scale_node->inputIndex = {2, 3, 4};
scale_node->outputIndex = {5};
}
scale_node->primitive = std::make_unique<schema::PrimitiveT>();
scale_node->primitive->value.type = schema::PrimitiveType_Scale;
auto prim2 = new schema::ScaleT;
scale_node->primitive->value.value = prim2;
scale_node->name = "scale";
meta_graph->nodes.emplace_back(std::move(scale_node));
// input 0: data
auto input0 = std::make_unique<schema::TensorT>();
input0->nodeType = schema::NodeType::NodeType_ValueNode;
input0->format = schema::Format_NHWC;
input0->dataType = TypeId::kNumberTypeFloat32;
input0->dims = {1, 5, 5, 3};
input0->offset = -1;
meta_graph->allTensors.emplace_back(std::move(input0));
// input 1: weight
auto input1 = std::make_unique<schema::TensorT>();
input1->nodeType = schema::NodeType::NodeType_ValueNode;
input1->format = schema::Format_KHWC;
input1->dataType = TypeId::kNumberTypeFloat32;
input1->dims = {8, 3, 3, 3};
input1->data.resize(sizeof(float) * 8 * 3 * 3 * 3);
meta_graph->allTensors.emplace_back(std::move(input1));
if (conv_with_bias) {
// input 00: bias
auto input00 = std::make_unique<schema::TensorT>();
input00->nodeType = schema::NodeType::NodeType_ValueNode;
input00->format = schema::Format_NHWC;
input00->dataType = TypeId::kNumberTypeFloat32;
input00->dims = {1, 5, 5, 3};
input00->offset = -1;
meta_graph->allTensors.emplace_back(std::move(input00));
}
// conv output
auto conv_output = std::make_unique<schema::TensorT>();
conv_output->nodeType = schema::NodeType::NodeType_Parameter;
conv_output->format = schema::Format_NHWC;
conv_output->dataType = TypeId::kNumberTypeFloat32;
conv_output->dims = {1, 5, 5, 8};
meta_graph->allTensors.emplace_back(std::move(conv_output));
// scale weight input
auto input2 = std::make_unique<schema::TensorT>();
input2->nodeType = schema::NodeType::NodeType_ValueNode;
input2->format = schema::Format_NHWC;
input2->dataType = TypeId::kNumberTypeFloat32;
input2->dims = {1, 5, 5, 8};
input2->data.resize(sizeof(float) * 8 * 5 * 5);
meta_graph->allTensors.emplace_back(std::move(input2));
// scale bias input
auto input3 = std::make_unique<schema::TensorT>();
input3->nodeType = schema::NodeType::NodeType_ValueNode;
input3->format = schema::Format_NHWC;
input3->dataType = TypeId::kNumberTypeFloat32;
input3->dims = {1, 5, 5, 8};
input3->data.resize(sizeof(float) * 8 * 5 * 5);
meta_graph->allTensors.emplace_back(std::move(input3));
// final scale output
auto output = std::make_unique<schema::TensorT>();
output->nodeType = schema::NodeType::NodeType_Parameter;
output->format = schema::Format_NHWC;
output->dataType = TypeId::kNumberTypeFloat32;
output->dims = {1, 5, 5, 8};
meta_graph->allTensors.emplace_back(std::move(output));
if (conv_with_bias) {
meta_graph->inputIndex = {0};
meta_graph->outputIndex = {6};
} else {
meta_graph->inputIndex = {0};
meta_graph->outputIndex = {5};
}
return meta_graph;
}
} // namespace
TEST_F(ConvScaleFusionTest, TestConvScaleNode) {
auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2D, true);
auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
lite::AnfTransform anf_transform;
auto new_graph = anf_transform.Transform(func_graph);
ASSERT_NE(nullptr, new_graph);
auto new_meta_graph = lite::Export(new_graph);
ASSERT_EQ(new_meta_graph->nodes.size(), 1);
for (auto &cnode : new_meta_graph->nodes) {
ASSERT_EQ(cnode->primitive->value.AsConv2D()->hasBias, true);
}
}
TEST_F(ConvScaleFusionTest, TestDepthwiseConvScaleNode) {
auto meta_graph = BuildGraph(schema::PrimitiveType_DepthwiseConv2D, false);
auto func_graph = lite::ModelParser::Fb2Anf(meta_graph.get());
lite::AnfTransform anf_transform;
auto new_graph = anf_transform.Transform(func_graph);
ASSERT_NE(nullptr, new_graph);
auto new_meta_graph = lite::Export(new_graph);
ASSERT_EQ(new_meta_graph->nodes.size(), 1);
for (auto &cnode : new_meta_graph->nodes) {
ASSERT_EQ(cnode->primitive->value.AsDepthwiseConv2D()->hasBias, true);
ASSERT_EQ(cnode->inputIndex.size(), 3);
}
}
} // namespace mindspore
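The Scale fusion asserted above is the same per-channel trick: y_c = s_c * x_c + t_c after a conv folds into W'_c = s_c * W_c and b'_c = s_c * b_c + t_c, which is why a conv built without a bias still ends up with hasBias == true and three inputs. A minimal sketch (illustrative only, channel-major KHWC weights assumed):
#include <vector>
// Fold a per-channel Scale (multiply by scale, add offset) into the conv.
void FoldScale(std::vector<float> *weight, std::vector<float> *bias,
               const std::vector<float> &scale, const std::vector<float> &offset,
               int out_channels) {
  const int per_channel = static_cast<int>(weight->size()) / out_channels;
  bias->resize(out_channels, 0.0f);  // create a zero bias if the conv had none
  for (int c = 0; c < out_channels; ++c) {
    for (int i = 0; i < per_channel; ++i) {
      (*weight)[c * per_channel + i] *= scale[c];
    }
    (*bias)[c] = (*bias)[c] * scale[c] + offset[c];
  }
}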

View File

@ -49,6 +49,8 @@ set(ANF_SRC
${CCSRC_DIR}/pybind_api/export_flags.cc
${CCSRC_DIR}/utils/context/context_extends.cc
${CCSRC_DIR}/frontend/parallel/costmodel_context.cc
${CCSRC_DIR}/backend/optimizer/common/pattern_engine.cc
${CCSRC_DIR}/backend/optimizer/common/visit.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/graph_utils_extends.cc
)
@ -75,9 +77,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/node_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/optimizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/pass_manager.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/pattern_engine.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/visit.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/common/gllo_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/fusion/conv_biasadd_fusion.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/fusion/conv_activation_fusion.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/gllo/fusion/conv_transform_fusion.cc

View File

@ -90,7 +90,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
return nullptr;
}
// run the ANF fusion passes on the imported graph
graph = anfTransform->Transform(graph);
CreateQuantizer(graph, flag);
if (mQuantizer != nullptr) {

View File

@ -100,20 +100,20 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
// }
// fusion
{
Optimizer fusionOptimizer;
fusionOptimizer.AddPass(new (std::nothrow) ConvBiasAddFusionPass());
fusionOptimizer.AddPass(new (std::nothrow) ConvBNFusionPass());
fusionOptimizer.AddPass(new (std::nothrow) ConvScaleFusionPass());
fusionOptimizer.AddPass(new (std::nothrow) ConvReluFusionPass());
fusionOptimizer.AddPass(new (std::nothrow) ConvRelu6FusionPass());
fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
status = fusionOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed";
return status;
}
}
// {
// Optimizer fusionOptimizer;
// fusionOptimizer.AddPass(new (std::nothrow) ConvBiasAddFusionPass());
// fusionOptimizer.AddPass(new (std::nothrow) ConvBNFusionPass());
// fusionOptimizer.AddPass(new (std::nothrow) ConvScaleFusionPass());
// fusionOptimizer.AddPass(new (std::nothrow) ConvReluFusionPass());
// fusionOptimizer.AddPass(new (std::nothrow) ConvRelu6FusionPass());
// fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
// status = fusionOptimizer.Run(graphDefT);
// if (status != RET_OK && status != RET_NO_CHANGE) {
// MS_LOG(ERROR) << "Run fusionOptimizer graphPasses Failed";
// return status;
// }
// }
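// NOTE: the GraphDef-level fusion block above is kept only for reference; the
// equivalent ANF-based fusion passes (conv + bias/bn/scale folding) now run
// earlier, in AnfTransform::Transform during conversion.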
// weight format trans
if (ctx.formatTrans) {